Skip to content
Closed
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
2d2aaaf
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 15, 2014
6b5651e
Updates based on code review. 1 major change: persisting to memory +…
jkbradley Aug 15, 2014
5f2dec2
Fixed scalastyle issue in TreePoint
jkbradley Aug 15, 2014
f40381c
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 15, 2014
797f68a
Fixed DecisionTreeSuite bug for training second level. Needed to upd…
jkbradley Aug 15, 2014
931a3a7
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 15, 2014
6a38f48
Added DTMetadata class for cleaner code
jkbradley Aug 16, 2014
db0d773
scala style fix
jkbradley Aug 16, 2014
ac0b9f8
Small updates based on code review.
jkbradley Aug 16, 2014
3726d20
Small code improvements based on code review.
jkbradley Aug 17, 2014
a0ed0da
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
jkbradley Aug 17, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
877 changes: 387 additions & 490 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impl

import scala.collection.mutable

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD


/**
* Learning and dataset metadata for DecisionTree.
*
* @param featureArity Map: categorical feature index --> arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
*/
private[tree] class DTMetadata(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: DecisionTreeMetadata, TreeMetadata or Metadata as class name alternatives.

val numFeatures: Int,
val numExamples: Long,
val numClasses: Int,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful if we add doc for numClasses mentioning its value for regression problems.

val maxBins: Int,
val featureArity: Map[Int, Int],
val unorderedFeatures: Set[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {

def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

def isClassification: Boolean = numClasses >= 2

def isMulticlass: Boolean = numClasses > 2

def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)

def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)

def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)

}

private[tree] object DTMetadata {

def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DTMetadata = {

val numFeatures = input.take(1)(0).features.size
val numExamples = input.count()
val numClasses = strategy.algo match {
case Classification => strategy.numClassesForClassification
case Regression => 0
}

val maxBins = math.min(strategy.maxBins, numExamples).toInt

val unorderedFeatures = new mutable.HashSet[Int]()
if (numClasses > 2) {
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
val numUnorderedBins = (1 << k - 1) - 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check the value of k first. If k > 30, the result will be unexpected. Using 1L instead of 1 may help.

if (numUnorderedBins < maxBins) {
unorderedFeatures.add(f)
} else {
// TODO: Allow this case, where we simply will know nothing about some categories?
require(k < maxBins, "maxBins should be greater than max categories " +
"in categorical features")
}
}
} else {
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we allow maxBins >= k in binary classification?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code actually relies on this since it uses one more bin than it needs to. The correction is in my next code update, sorry.

require(k < maxBins, "maxBins should be greater than max categories " +
"in categorical features")
}
}

new DTMetadata(numFeatures, numExamples, numClasses, maxBins,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
strategy.impurity, strategy.quantileCalculationStrategy)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD

Expand Down Expand Up @@ -48,50 +47,35 @@ private[tree] object TreePoint {
* Convert an input dataset into its TreePoint representation,
* binning feature values in preparation for DecisionTree training.
* @param input Input dataset.
* @param strategy DecisionTree training info, used for dataset metadata.
* @param bins Bins for features, of size (numFeatures, numBins).
* @param metadata Learning and dataset metadata
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
input: RDD[LabeledPoint],
strategy: Strategy,
bins: Array[Array[Bin]]): RDD[TreePoint] = {
bins: Array[Array[Bin]],
metadata: DTMetadata): RDD[TreePoint] = {
input.map { x =>
TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
strategy.categoricalFeaturesInfo)
TreePoint.labeledPointToTreePoint(x, bins, metadata)
}
}

/**
* Convert one LabeledPoint into its TreePoint representation.
* @param bins Bins for features, of size (numFeatures, numBins).
* @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
isMulticlassClassification: Boolean,
bins: Array[Array[Bin]],
categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
metadata: DTMetadata): TreePoint = {

val numFeatures = labeledPoint.features.size
val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
val featureInfo = categoricalFeaturesInfo.get(featureIndex)
val isFeatureContinuous = featureInfo.isEmpty
if (isFeatureContinuous) {
arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
bins, categoricalFeaturesInfo)
} else {
val featureCategories = featureInfo.get
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
val isUnorderedFeature =
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
isUnorderedFeature, bins, categoricalFeaturesInfo)
}
arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
metadata.isUnordered(featureIndex), bins, metadata.featureArity)
featureIndex += 1
}

Expand Down
18 changes: 14 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
* Used for "binning" the features bins for faster best split calculation. For a continuous
* feature, a bin is determined by a low and a high "split". For a categorical feature,
* the a bin is determined using a single label value (category).
* Used for "binning" the features bins for faster best split calculation.
*
* For a continuous feature, the bin is determined by a low and a high split,
* where an example with featureValue falls into the bin s.t.
* lowSplit.threshold < featureValue <= highSplit.threshold.
*
* For ordered categorical features, there is a 1-1-1 correspondence between
* bins, splits, and feature values. The bin is determined by category/feature value.
* However, the bins are not necessarily ordered by feature value;
* they are ordered using impurity.
* For unordered categorical features, there is a 1-1 correspondence between bins, splits,
* where bins and splits correspond to subsets of feature values (in highSplit.categories).
*
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin for binary classification
* @param category categorical label value accepted in the bin for ordered features
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predictIfLeaf(features)
topNode.predict(features)
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ class Node (

/**
* predict value if node is not leaf
* @param feature feature value
* @param features feature value
* @return predicted value
*/
def predictIfLeaf(feature: Vector) : Double = {
def predict(features: Vector) : Double = {
if (isLeaf) {
predict
} else{
if (split.get.featureType == Continuous) {
if (feature(split.get.feature) <= split.get.threshold) {
leftNode.get.predictIfLeaf(feature)
if (features(split.get.feature) <= split.get.threshold) {
leftNode.get.predict(features)
} else {
rightNode.get.predictIfLeaf(feature)
rightNode.get.predict(features)
}
} else {
if (split.get.categories.contains(feature(split.get.feature))) {
leftNode.get.predictIfLeaf(feature)
if (split.get.categories.contains(features(split.get.feature))) {
leftNode.get.predict(features)
} else {
rightNode.get.predictIfLeaf(feature)
rightNode.get.predict(features)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
* :: DeveloperApi ::
* Split applied to a feature
* @param feature feature index
* @param threshold threshold for continuous feature
* @param threshold Threshold for continuous feature.
* Split left if feature <= threshold, else right.
* @param featureType type of feature -- categorical or continuous
* @param categories accepted values for categorical variables
* @param categories Split left if categorical feature value is in this set, else right.
*/
@DeveloperApi
case class Split(
Expand Down
Loading