Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.DoubleType

/**
Expand Down Expand Up @@ -136,8 +136,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
val strategy = getOldStrategy(categoricalFeatures, numClasses)
instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)
probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed)

val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
Expand Down Expand Up @@ -210,6 +210,18 @@ class DecisionTreeClassificationModel private[ml] (
rootNode.predictImpl(features).prediction
}

/**
 * Runs the parent `transform` and, when `leafCol` is set to a non-empty name, appends a
 * column holding the value of `predictLeaf` for each row's feature vector.
 *
 * @param dataset input dataset; its schema is passed to `transformSchema` before any work
 * @return the parent's output, plus the `leafCol` column when that param is non-empty
 */
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)

  val outputData = super.transform(dataset)
  if ($(leafCol).nonEmpty) {
    // The leaf index is derived from the features column, not from the parent's outputs.
    val leafUDF = udf { features: Vector => predictLeaf(features) }
    outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
  } else {
    outputData
  }
}

/**
 * Builds the raw prediction for one feature vector from a copy of the impurity statistics
 * stored on the node that `rootNode.predictImpl` returns.
 */
override protected def predictRaw(features: Vector): Vector = {
  val reachedNode = rootNode.predictImpl(features)
  // Clone so the returned vector does not share its backing array with the node.
  val statsCopy = reachedNode.impurityStats.stats.clone()
  Vectors.dense(statsCopy)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._

/**
Expand Down Expand Up @@ -191,8 +191,8 @@ class GBTClassifier @Since("1.4.0") (

instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity,
lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
validationIndicatorCol, validationTol)
instr.logNumClasses(numClasses)
Expand Down Expand Up @@ -286,6 +286,18 @@ class GBTClassificationModel private[ml](
/** Returns this model's `_treeWeights` array directly (no copy is made here). */
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

/**
 * Delegates to the parent `transform`, then optionally appends the leaf-index column.
 *
 * When `leafCol` is non-empty, the extra column holds `predictLeaf` applied to each row's
 * `featuresCol` value; otherwise the parent's result is returned unchanged.
 */
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)

  val predicted = super.transform(dataset)
  val leafColName = $(leafCol)
  if (leafColName.isEmpty) {
    predicted
  } else {
    val toLeafIndices = udf { features: Vector => predictLeaf(features) }
    predicted.withColumn(leafColName, toLeafIndices(col($(featuresCol))))
  }
}

override def predict(features: Vector): Double = {
// If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
if (isDefined(thresholds)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
/**
* (private[ml]) Params for probabilistic classification.
*/
private[classification] trait ProbabilisticClassifierParams
private[ml] trait ProbabilisticClassifierParams
extends ClassifierParams with HasProbabilityCol with HasThresholds {
override protected def validateAndTransformSchema(
schema: StructType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
Expand Down Expand Up @@ -135,8 +136,9 @@ class RandomForestClassifier @Since("1.4.0") (
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB,
minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds,
checkpointInterval)

val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
Expand Down Expand Up @@ -207,6 +209,18 @@ class RandomForestClassificationModel private[ml] (
/** Returns this model's `_treeWeights` array directly (no copy is made here). */
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

/**
 * Produces the parent's transformed output and, if a leaf column name is configured,
 * adds that column with the result of `predictLeaf` on each row's features.
 *
 * @param dataset input dataset whose schema is first passed through `transformSchema`
 * @return the parent's output, with the `leafCol` column appended when it is non-empty
 */
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)

  val baseOutput = super.transform(dataset)
  val wantsLeafColumn = $(leafCol).nonEmpty
  if (!wantsLeafColumn) {
    baseOutput
  } else {
    val leafIndexOf = udf { features: Vector => predictLeaf(features) }
    val featureColumn = col($(featuresCol))
    baseOutput.withColumn($(leafCol), leafIndexOf(featureColumn))
  }
}

override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,25 +212,28 @@ class DecisionTreeRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
transformImpl(dataset)
}

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]

if ($(predictionCol).nonEmpty) {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictUDF = udf { features: Vector => predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
}

if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
val predictVarianceUDF = udf { features: Vector => predictVariance(features) }
predictionColNames :+= $(varianceCol)
predictionColumns :+= predictVarianceUDF(col($(featuresCol)))
}

if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
predictionColNames :+= $(leafCol)
predictionColumns :+= leafUDF(col($(featuresCol)))
}

if (predictionColNames.nonEmpty) {
dataset.withColumns(predictionColNames, predictionColumns)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._

/**
Expand Down Expand Up @@ -169,7 +169,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)

instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType,
instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
validationIndicatorCol, validationTol)
Expand Down Expand Up @@ -245,12 +245,33 @@ class GBTRegressionModel private[ml](
/** Returns this model's `_treeWeights` array directly (no copy is made here). */
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

/**
 * Appends the configured output columns to `dataset`.
 *
 * A column is added for `predictionCol` (via `predict`) and for `leafCol` (via
 * `predictLeaf`) only when the corresponding param is a non-empty name. Both UDFs read
 * the model through a single broadcast of `this`. When neither column is requested, a
 * warning is logged and the input is returned as a DataFrame without new columns.
 *
 * @param dataset input dataset; its schema is passed to `transformSchema` first
 * @return `dataset` with the requested output columns added
 */
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)

  var predictionColNames = Seq.empty[String]
  var predictionColumns = Seq.empty[Column]

  // Broadcast once; every UDF below reads the model from the broadcast value.
  val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)

  if ($(predictionCol).nonEmpty) {
    val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
    predictionColNames :+= $(predictionCol)
    predictionColumns :+= predictUDF(col($(featuresCol)))
  }

  if ($(leafCol).nonEmpty) {
    val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
    predictionColNames :+= $(leafCol)
    predictionColumns :+= leafUDF(col($(featuresCol)))
  }

  if (predictionColNames.nonEmpty) {
    dataset.withColumns(predictionColNames, predictionColumns)
  } else {
    this.logWarning(s"$uid: GBTRegressionModel.transform() does nothing" +
      " because no output columns were set.")
    dataset.toDF()
  }
}

override def predict(features: Vector): Double = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}

/**
Expand Down Expand Up @@ -123,7 +123,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S

instr.logPipelineStage(this)
instr.logDataset(instances)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)

Expand Down Expand Up @@ -191,12 +191,33 @@ class RandomForestRegressionModel private[ml] (
/** Returns this model's `_treeWeights` array directly (no copy is made here). */
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

/**
 * Appends the configured output columns to `dataset`.
 *
 * A column is added for `predictionCol` (via `predict`) and for `leafCol` (via
 * `predictLeaf`) only when the corresponding param is a non-empty name. Both UDFs read
 * the model through a single broadcast of `this`. When neither column is requested, a
 * warning is logged and the input is returned as a DataFrame without new columns.
 *
 * @param dataset input dataset; its schema is passed to `transformSchema` first
 * @return `dataset` with the requested output columns added
 */
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)

  var predictionColNames = Seq.empty[String]
  var predictionColumns = Seq.empty[Column]

  // Broadcast once; every UDF below reads the model from the broadcast value.
  val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)

  if ($(predictionCol).nonEmpty) {
    val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
    predictionColNames :+= $(predictionCol)
    predictionColumns :+= predictUDF(col($(featuresCol)))
  }

  if ($(leafCol).nonEmpty) {
    val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
    predictionColNames :+= $(leafCol)
    predictionColumns :+= leafUDF(col($(featuresCol)))
  }

  if (predictionColNames.nonEmpty) {
    dataset.withColumns(predictionColNames, predictionColumns)
  } else {
    this.logWarning(s"$uid: RandomForestRegressionModel.transform() does nothing" +
      " because no output columns were set.")
    dataset.toDF()
  }
}

override def predict(features: Vector): Double = {
Expand Down
38 changes: 33 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ import org.apache.spark.util.collection.OpenHashMap

/**
* Abstraction for Decision Tree models.
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
*/
private[spark] trait DecisionTreeModel {

Expand Down Expand Up @@ -78,13 +76,34 @@ private[spark] trait DecisionTreeModel {

/**
 * Convert to spark.mllib DecisionTreeModel (losing some information).
 *
 * Declared without a body here, so each concrete model supplies the conversion.
 */
private[spark] def toOld: OldDecisionTreeModel

/**
 * Walks the subtree rooted at `node` depth-first, descending into the left child before
 * the right one, and yields every leaf it reaches.
 *
 * @param node root of the subtree to walk
 * @return iterator over that subtree's leaves, in left-to-right order
 */
private def leafIterator(node: Node): Iterator[LeafNode] = node match {
  case internal: InternalNode =>
    val fromLeft = leafIterator(internal.leftChild)
    fromLeft ++ leafIterator(internal.rightChild)
  case leaf: LeafNode =>
    Iterator.single(leaf)
}

/**
 * Maps each leaf under `rootNode` to its position in the order produced by
 * `leafIterator`, counting from 0. Computed on first access and excluded from
 * serialization via `@transient`.
 */
@transient private lazy val leafIndices: Map[LeafNode, Int] = {
  leafIterator(rootNode).zipWithIndex.toMap
}

/**
 * Finds the leaf that `rootNode.predictImpl` reaches for the given features and reports
 * its index.
 *
 * @param features feature vector to route through the tree
 * @return the leaf's index as a Double; leaves are numbered from 0 in the left-to-right
 *         depth-first order used to build `leafIndices`
 */
def predictLeaf(features: Vector): Double = {
  val reachedLeaf = rootNode.predictImpl(features)
  val leafIndex = leafIndices(reachedLeaf)
  leafIndex.toDouble
}
}

/**
* Abstraction for models which are ensembles of decision trees
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
*
* @tparam M Type of tree model in this ensemble
*/
private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
Expand Down Expand Up @@ -118,6 +137,15 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {

/** Sum of `numNodes` over every tree in this ensemble, computed on first access. */
lazy val totalNumNodes: Int = trees.foldLeft(0)((acc, tree) => acc + tree.numNodes)

/**
 * Asks every tree in the ensemble for the leaf it assigns to the feature vector.
 *
 * @param features feature vector to route through each tree
 * @return dense vector with one entry per tree, in the order of `trees`, holding that
 *         tree's `predictLeaf` result
 */
def predictLeaf(features: Vector): Vector = {
  val perTreeLeaf: Array[Double] = for (tree <- trees) yield tree.predictLeaf(features)
  Vectors.dense(perTreeLeaf)
}
}

private[ml] object TreeEnsembleModel {
Expand Down
Loading