-
Notifications
You must be signed in to change notification settings - Fork 59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tree #3
base: master
Are you sure you want to change the base?
Tree #3
Changes from all commits
1eba6f3
e2231ad
95d45ab
6754a40
49b8797
376e241
35cebe6
9ad0dd5
fc773de
e29493c
6047ed8
2a1185b
b2447a8
68ad6c8
729a3b1
d956cd3
422ed7d
020069a
271d1f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,7 @@ project/plugins/project/ | |
#Eclipse specific | ||
.classpath | ||
.project | ||
|
||
#IDEA specific | ||
.idea | ||
.idea_modules |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ml.tree | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should refactor this as mli.ml.tree |
||
import javax.naming.OperationNotSupportedException | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.mllib.regression.LabeledPoint | ||
import org.apache.spark.mllib.classification.ClassificationModel | ||
import org.apache.spark.SparkContext | ||
import org.apache.spark.util.StatCounter | ||
import org.apache.spark.Logging | ||
import ml.tree.impurity.{Variance, Entropy, Gini, Impurity} | ||
import ml.tree.strategy.Strategy | ||
import ml.tree.split.{SplitPredicate, Split} | ||
import org.apache.spark.broadcast.Broadcast | ||
import scala.Some | ||
import ml.tree.strategy.Strategy | ||
import ml.tree.split.Split | ||
import ml.tree.node._ | ||
import ml.tree.Metrics._ | ||
import scala.Some | ||
import ml.tree.strategy.Strategy | ||
import ml.tree.split.Split | ||
|
||
|
||
/* | ||
* Class for building the Decision Tree model. Should be used for both classification and regression tree. | ||
*/ | ||
class DecisionTree ( | ||
val input: RDD[(Double, Array[Double])], //input RDD | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be great if this could take an MLNumericTable as input with a "targetCol" that indicates which column to use as response. This may impact some design choices later (e.g. enforces a row-oriented perspective). |
||
val maxDepth: Int, // depth of the tree | ||
val numSplitPredicates: Int, // number of bins per features | ||
val fraction: Double, // fraction of the data to be used for performing quantile calculation | ||
val strategy: Strategy, // classification or regression | ||
val impurity: Impurity, // impurity calculation strategy (variance, gini, entropy, etc.) | ||
val sparkContext : SparkContext) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto for MLContext here. |
||
|
||
//Calculating length of the features | ||
val featureLength = input.first._2.length | ||
println("feature length = " + featureLength) | ||
|
||
//Sampling a fraction of the input RDD | ||
val sampledData = input.sample(false, fraction, 42).cache() | ||
println("sampled data size for quantile calculation = " + sampledData.count) | ||
|
||
//Sorting the sampled data along each feature and storing it for quantile calculation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're doing exact quantiling here, which is good. The downside here is it requires a sort on each data column, which is potentially expensive. I wonder how bad it would be to approximate with low-high and split values at intervals of (low-high)/#bins - this way we'd effectively have fixed width histograms in one pass over the data (just collect max and min for each feature). This is pretty outlier sensitive, but maybe worth thinking about. Finally, I'm wondering how effective this strategy will be on low-variance or highly sparse columns where 90% of the data has the same value. In that case, your conversion of split predicates to sets effectively drops the #of split candidates for a particular column down to close to 0 (or even to 0), even though there might be some signal there. I've never used Decision Trees on data with that kind of skew in practice, but I imagine it's common in e.g. online marketing. |
||
println("started sorting sampled data") | ||
val sortedSampledFeatures = { | ||
val sortedFeatureArray = new Array[Array[Double]](featureLength) | ||
0 until featureLength foreach { | ||
i => sortedFeatureArray(i) = sampledData.map(x => x._2(i) -> None).sortByKey(true).map(_._1).collect() | ||
} | ||
sortedFeatureArray | ||
} | ||
println("finished sorting sampled data") | ||
|
||
val numSamples = sampledData.count | ||
println("num samples = " + numSamples) | ||
|
||
// Calculating the index to jump to find the quantile points | ||
val stride = scala.math.max(numSamples / numSplitPredicates, 1) | ||
println("stride = " + stride) | ||
|
||
//Calculating all possible splits for the features | ||
println("calculating all possible splits for features") | ||
val allSplitsList = for { | ||
featureIndex <- 0 until featureLength; | ||
index <- stride until numSamples - 1 by stride | ||
} yield createSplit(featureIndex, index) | ||
println("finished calculating all possible splits for features") | ||
|
||
//Remove duplicate splits. Especially help for one-hot encoded categorical variables. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the current setup handle categorical variables? It seems like we're treating them just like ordered numeric variables now? |
||
val allSplits = sparkContext.broadcast(allSplitsList.toSet) | ||
|
||
//for (split <- allSplits) yield println(split) | ||
|
||
/* | ||
* Find the exact value using feature index and index into the sorted features | ||
*/ | ||
def valueAtRDDIndex(featuresIndex: Long, index: Long): Double = { | ||
sortedSampledFeatures(featuresIndex.toInt)(index.toInt) | ||
} | ||
|
||
/* | ||
* Create splits using feature index and index into the sorted features | ||
*/ | ||
def createSplit(featureIndex: Int, index: Long): Split = { | ||
new Split(featureIndex, valueAtRDDIndex(featureIndex, index)) | ||
} | ||
|
||
def buildTree(): Node = { | ||
|
||
println("building decision tree") | ||
|
||
strategy match { | ||
case Strategy("Classification") => new TopClassificationNode(input, allSplits, impurity, strategy, maxDepth) | ||
case Strategy("Regression") => { | ||
val count = input.count | ||
//TODO: calculate mean and variance together | ||
val variance = input.map(x => x._1).variance | ||
val mean = input.map(x => x._1).mean | ||
val nodeStats = new NodeStats(count = Some(count), variance = Some(variance), mean = Some(mean)) | ||
new TopRegressionNode(input, nodeStats,allSplits, impurity, strategy, maxDepth) | ||
} | ||
} | ||
} | ||
|
||
} | ||
|
||
|
||
object DecisionTree { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For consistency with the rest of the MLI codebase, we might want wrapper objects like RegressionTreeModel and RegressionTreeAlgorithm as well as DecisionTreeModel/DecisionTreeAlgorithm and put those in mli.ml.regression and classification respectively. They'd be thin wrappers around DecisionTree.train() (and the common 'tree' module is fine to live in its own area!) |
||
def train( | ||
input: RDD[(Double, Array[Double])], | ||
numSplitPredicates: Int, | ||
strategy: Strategy, | ||
impurity: Impurity, | ||
maxDepth : Int, | ||
fraction : Double, | ||
sparkContext : SparkContext): Option[NodeModel] = { | ||
val tree = new DecisionTree( | ||
input = input, | ||
numSplitPredicates = numSplitPredicates, | ||
strategy = strategy, | ||
impurity = impurity, | ||
maxDepth = maxDepth, | ||
fraction = fraction, | ||
sparkContext = sparkContext) | ||
.buildTree | ||
.extractModel | ||
|
||
println("calculating performance on training data") | ||
val trainingError = { | ||
strategy match { | ||
case Strategy("Classification") => accuracyScore(tree, input) | ||
case Strategy("Regression") => meanSquaredError(tree, input) | ||
} | ||
} | ||
println("accuracy = " + trainingError) | ||
|
||
tree | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package ml.tree | ||
|
||
import org.apache.spark.SparkContext._ | ||
import ml.tree.node.NodeModel | ||
import org.apache.spark.rdd.RDD | ||
|
||
/* | ||
Helper methods for measuring performance of ML algorithms | ||
*/ | ||
object Metrics { | ||
|
||
//TODO: Make these generic MLTable metrics. | ||
def accuracyScore(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = { | ||
if (tree.isEmpty) return 1 //TODO: Throw exception | ||
val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count() | ||
val count = data.count() | ||
println("correct prediction count = " + correctCount) | ||
println("data count = " + count) | ||
correctCount.toDouble / count | ||
} | ||
|
||
//TODO: Make these generic MLTable metrics | ||
def meanSquaredError(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = { | ||
if (tree.isEmpty) return 1 //TODO: Throw exception | ||
val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean() | ||
println("meanSumOfSquares = " + meanSumOfSquares) | ||
meanSumOfSquares | ||
} | ||
|
||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#Decision Tree | ||
Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discusses its implementation in the Spark project. | ||
|
||
#Usage | ||
``` | ||
ml.tree.TreeRunner | ||
<master>[slices] | ||
--strategy <Classification,Regression> | ||
--trainDataDir path | ||
--testDataDir path | ||
[--maxDepth num] | ||
[--impurity <Gini,Entropy,Variance>] | ||
[--samplingFractionForSplitCalculation num] | ||
``` | ||
|
||
#Example | ||
``` | ||
sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification | ||
--trainDataDir ../train_data --testDataDir ../test_data | ||
--maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1 | ||
``` | ||
|
||
This command will create a decision tree model using the training data in the *trainDataDir* and calculate test error using the data in the *testDataDir*. The mis-classification error is calculated for a Classification *strategy* and mean squared error is calculated for the Regression *strategy*. | ||
|
||
#Performance testing | ||
To be done | ||
|
||
#Improvements | ||
* Print to dot files | ||
* Unit tests | ||
* Change fractions to quantiles | ||
* Add logging | ||
* Move metrics to a different package | ||
|
||
#Extensions | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
* Extremely randomized trees | ||
* Random forest | ||
* Boosting |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package ml.tree | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a nice program to have - we should really think about how we want to standardize on command line utils. I think a single driver program that can instantiate models/write models/results out to disk would be really valuable and save a bunch of duplicated effort. For now this is great though! |
||
|
||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.{Logging, SparkContext} | ||
import ml.tree.impurity.{Variance, Entropy, Gini} | ||
import ml.tree.strategy.Strategy | ||
|
||
import ml.tree.node.NodeModel | ||
import org.apache.spark.rdd.RDD | ||
|
||
import ml.tree.Metrics.{accuracyScore,meanSquaredError} | ||
|
||
object TreeRunner extends Logging { | ||
val usage = """ | ||
Usage: TreeRunner <master>[slices] --strategy <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--samplingFractionForSplitCalculation num] | ||
""" | ||
|
||
def main(args: Array[String]) { | ||
|
||
if (args.length < 2) { | ||
System.err.println(usage) | ||
System.exit(1) | ||
} | ||
|
||
/**START Experimental*/ | ||
System.setProperty("spark.cores.max", "8") | ||
/**END Experimental*/ | ||
val sc = new SparkContext(args(0), "Decision Tree Runner", | ||
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) | ||
|
||
|
||
val arglist = args.toList.drop(1) | ||
type OptionMap = Map[Symbol, Any] | ||
|
||
def nextOption(map : OptionMap, list: List[String]) : OptionMap = { | ||
def isSwitch(s : String) = (s(0) == '-') | ||
list match { | ||
case Nil => map | ||
case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail) | ||
case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) | ||
case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) | ||
case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) | ||
case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) | ||
case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail) | ||
case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) | ||
case option :: tail => println("Unknown option "+option) | ||
exit(1) | ||
} | ||
} | ||
val options = nextOption(Map(),arglist) | ||
println(options) | ||
//TODO: Add check for acceptable string inputs | ||
|
||
val trainData = TreeUtils.loadLabeledData(sc, options.get('trainDataDir).get.toString) | ||
val strategyStr = options.get('strategy).get.toString | ||
val impurityStr = options.getOrElse('impurity,"Gini").toString | ||
val impurity = { | ||
impurityStr match { | ||
case "Gini" => Gini | ||
case "Entropy" => Entropy | ||
case "Variance" => Variance | ||
} | ||
} | ||
val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt | ||
val fraction = options.getOrElse('samplingFractionForSplitCalculation,"1.0").toString.toDouble | ||
|
||
val tree = DecisionTree.train( | ||
input = trainData, | ||
numSplitPredicates = 1000, | ||
strategy = new Strategy(strategyStr), | ||
impurity = impurity, | ||
maxDepth = maxDepth, | ||
fraction = fraction, | ||
sparkContext = sc) | ||
println(tree) | ||
//println("prediction = " + tree.get.predict(Array(1.0, 2.0))) | ||
|
||
println("loading test data") | ||
val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString) | ||
|
||
println("calculating performance of test data") | ||
val testError = { | ||
strategyStr match { | ||
case "Classification" => accuracyScore(tree, testData) | ||
case "Regression" => meanSquaredError(tree, testData) | ||
} | ||
} | ||
println("accuracy = " + testError) | ||
|
||
} | ||
|
||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package ml.tree | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.rdd.RDD | ||
|
||
|
||
//TODO: Deprecate this when we find something equivalent in ml utils | ||
/** | ||
* Helper methods to load and save data | ||
* Data format: | ||
* <l>, <f1> <f2> ... | ||
* where <f1>, <f2> are feature values in Double and <l> is the corresponding label as Double. | ||
*/ | ||
object TreeUtils { | ||
|
||
/** | ||
* @param sc SparkContext | ||
* @param dir Directory to the input data files. | ||
* @return An RDD of tuples. For each tuple, the first element is the label, and the second | ||
* element represents the feature values (an array of Double). | ||
*/ | ||
def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably just be using the standard MLContext.loadCsvFile here (though it should probably generate an MLNumericTable). |
||
sc.textFile(dir).map { line => | ||
val parts = line.trim().split(",") | ||
val label = parts(0).toDouble | ||
val features = parts.slice(1,parts.length).map(_.toDouble) | ||
//val features = parts.slice(1, 30).map(_.toDouble) | ||
(label, features) | ||
} | ||
} | ||
|
||
def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) { | ||
val dataStr = data.map(x => x._1 + "," + x._2.mkString(" ")) | ||
dataStr.saveAsTextFile(dir) | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good idea, thanks.