From 1eba6f31c31d91cb7a9b2167c01fdc51005ee9e6 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 26 Sep 2013 21:32:24 -0700 Subject: [PATCH 01/19] migrating tree code to MLI --- src/main/scala/ml/tree/DecisionTree.scala | 613 ++++++++++++++++++++++ 1 file changed, 613 insertions(+) create mode 100644 src/main/scala/ml/tree/DecisionTree.scala diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala new file mode 100644 index 0000000..6c51795 --- /dev/null +++ b/src/main/scala/ml/tree/DecisionTree.scala @@ -0,0 +1,613 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ml.tree +import javax.naming.OperationNotSupportedException +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.classification.ClassificationModel +import org.apache.spark.SparkContext +import org.apache.spark.util.StatCounter + + +/* + * Abstract Node class as a template for implementing various types of nodes in the decision tree. + */ +abstract class Node { + //Method for checking whether the class has any left/right child nodes. + def isLeaf: Boolean + //Left/Right child nodes + def left: Node + def right: Node + //Depth of the node from the top node + def depth: Int + //RDD data as an input to the node + def data: RDD[(Double, Array[Double])] + //List of split predicates applied to the base RDD thus far + def splitPredicates: List[SplitPredicate] + // Split to arrive at the node + def splitPredicate: Option[SplitPredicate] + //Extract model + def extractModel: Option[NodeModel] = { + //Add probability logic + if (!splitPredicate.isEmpty) { Some(new NodeModel(splitPredicate, left.extractModel, right.extractModel, depth, isLeaf, Some(prediction))) } + else { + // Using -1 as depth + Some(new NodeModel(None, None, None, depth, isLeaf, Some(prediction))) + } + } + def prediction: Prediction +} + +/** + * The decision tree model class that + */ +class NodeModel( + val splitPredicate: Option[SplitPredicate], + val trueNode: Option[NodeModel], + val falseNode: Option[NodeModel], + val depth: Int, + val isLeaf: Boolean, + val prediction: Option[Prediction]) extends ClassificationModel { + + override def toString() = if (!splitPredicate.isEmpty) { + "[" + trueNode.get + "\n" + "[" + "depth = " + depth + ", split predicate = " + this.splitPredicate.get + ", predict = " + this.prediction + "]" + "]\n" + falseNode.get + } else { + "Leaf : " + "depth = " + depth + ", predict = " + prediction + ", isLeaf = " + isLeaf + } + + /** + * Predict values for the given data set using the model trained. 
+   *
+   * @param testData RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+    testData.map { x => predict(x) }
+  }
+
+  /**
+   * Predict values for a single data point using the model trained.
+   *
+   * @param testData array representing a single data point
+   * @return Int prediction from the trained model
+   */
+  def predict(testData: Array[Double]): Double = {
+    //TODO: Modify this logic to handle regression
+    val pred = prediction.get
+    if (this.isLeaf) {
+      if (pred.prob > 0.5) 1 else 0
+    } else {
+      val spPred = splitPredicate.get
+      if (testData(spPred.split.feature) <= spPred.split.threshold) {
+        trueNode.get.predict(testData)
+      } else {
+        falseNode.get.predict(testData)
+      }
+    }
+  }
+
+}
+
+/*
+ * Class used to store the prediction values at each node of the tree.
+ */
+class Prediction(val prob: Double, val distribution: Map[Double, Double]) {
+  override def toString = { "probability = " + prob + ", distribution = " + distribution }
+}
+
+/*
+ * Class for storing splits -- feature index and threshold
+ */
+case class Split(val feature: Int, val threshold: Double) {
+  override def toString = "feature = " + feature + ", threshold = " + threshold
+}
+
+/*
+ * Class for storing the split predicate.
+ */
+class SplitPredicate(val split: Split, lessThanEqualTo: Boolean = true) extends Serializable {
+  override def toString = "split = " + split.toString + ", lessThanEqualTo = " + lessThanEqualTo
+}
+
+/*
+ * Class for building the Decision Tree model. Should be used for both classification and regression trees.
+ */
+class DecisionTree(
+  val input: RDD[(Double, Array[Double])], //input RDD
+  val maxDepth: Int, // depth of the tree
+  val numSplitPredicates: Int, // number of bins per feature
+  val fraction: Double, // fraction of the data to be used for performing quantile calculation
+  val strategy: Strategy, // classification or regression
+  val impurity: Impurity,
+  val sparkContext : SparkContext) { // impurity calculation strategy (variance, gini, entropy, etc.)
+
+  //Calculating length of the features
+  val featureLength = input.first._2.length
+  println("feature length = " + featureLength)
+
+  //Sampling a fraction of the input RDD
+  val sampledData = input.sample(false, fraction, 42).cache()
+
+  //Sorting the sampled data along each feature and storing it for quantile calculation
+  val sortedSampledFeatures = {
+    val sortedFeatureArray = new Array[RDD[Double]](featureLength)
+    0 until featureLength foreach {
+      i => sortedFeatureArray(i) = sampledData.map(x => x._2(i) -> None).sortByKey(true).map(_._1)
+    }
+    sortedFeatureArray
+  }
+
+  val numSamples = sampledData.count
+  println("num samples = " + numSamples)
+
+  // Calculating the index to jump to find the quantile points
+  val stride = scala.math.max(numSamples / numSplitPredicates, 1)
+  println("stride = " + stride)
+
+  //Calculating all possible splits for the features
+  val allSplitsList = for {
+    featureIndex <- 0 until featureLength;
+    index <- stride until numSamples - 1 by stride
+  } yield createSplit(featureIndex, index)
+
+  //Remove duplicate splits. This especially helps for one-hot encoded categorical variables.
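+  //The full candidate set is read by every per-node gain computation, so it is
+  //broadcast once to the workers rather than shipped inside each task closure.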
+ val allSplits = sparkContext.broadcast(allSplitsList.toSet) + + //for (split <- allSplits) yield println(split) + + /* + * Find the exact value using feature index and index into the sorted features + */ + def valueAtRDDIndex(featuresIndex: Long, index: Long): Double = { + sortedSampledFeatures(featuresIndex.toInt).collect()(index.toInt) + } + + /* + * Create splits using feature index and index into the sorted features + */ + def createSplit(featureIndex: Int, index: Long): Split = { + new Split(featureIndex, valueAtRDDIndex(featureIndex, index)) + } + + /* + * Empty Node class used to terminate leaf nodes + */ + private class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node { + def isLeaf = true + def left = throw new OperationNotSupportedException("EmptyNode.left") + def right = throw new OperationNotSupportedException("EmptyNode.right") + def depth = throw new OperationNotSupportedException("EmptyNode.depth") + def splitPredicates = throw new OperationNotSupportedException("EmptyNode.splitPredicates") + def splitPredicate = throw new OperationNotSupportedException("EmptyNode.splitPredicate") + override def toString() = "Empty" + val prediction: Prediction = { + val countZero: Double = data.filter(x => (x._1 == 0.0)).count + val countOne: Double = data.filter(x => (x._1 == 1.0)).count + val countTotal: Double = countZero + countOne + new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne)) + } + } + + /* + * Top node for building a classification tree + */ + private class TopClassificationNode extends ClassificationNode(input.cache, 0, List[SplitPredicate](), new NodeStats) { + override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" + } + + /* + * Class for each node in the classification tree + */ + private class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) + extends DecisionNode(data, depth, splitPredicates, nodeStats) { + + // Prediction at each classification node + val prediction: Prediction = { + val countZero: Double = data.filter(x => (x._1 == 0.0)).count + val countOne: Double = data.filter(x => (x._1 == 1.0)).count + val countTotal: Double = countZero + countOne + new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne)) + } + + //Static factory method. Put it in a better location. + def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) = new ClassificationNode(anyData, depth, splitPredicates, nodeStats) + + } + + /* + * Top node for building a regression tree + */ + private class TopRegressionNode(nodeStats : NodeStats) extends RegressionNode(input.cache, 0, List[SplitPredicate](), nodeStats) { + override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" + } + + /* + * Class for each node in the regression tree + */ + private class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) + extends DecisionNode(data, depth, splitPredicates, nodeStats) { + + // Prediction at each regression node + val prediction: Prediction = new Prediction(data.map(_._1).mean, Map()) + + //Static factory method. Put it in a better location. 
+ def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) = new RegressionNode(anyData, depth, splitPredicates, nodeStats) + } + + abstract class DecisionNode( + val data: RDD[(Double, Array[Double])], + val depth: Int, + val splitPredicates: List[SplitPredicate], + val nodeStats : NodeStats) extends Node { + //TODO: Change empty logic + val splits = splitPredicates.map(x => x.split) + //TODO: Think about the merits of doing BFS and removing the parents RDDs from memory instead of doing DFS like below. + val (left, right, splitPredicate, isLeaf) = createLeftRightChild() + override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]" + def createNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats): DecisionNode + def createLeftRightChild(): (Node, Node, Option[SplitPredicate], Boolean) = { + if (depth > maxDepth) { + (new LeafNode(data), new LeafNode(data), None, true) + } else { + println("split count " + splits.length) + val split_gain = findBestSplit(nodeStats) + val (split, gain, leftNodeStats, rightNodeStats) = split_gain + println("Selected split = " + split + " with gain = " + gain, "left node stats = " + leftNodeStats + " right node stats = " + rightNodeStats) + if (split_gain._2 > 0) { + println("creating new nodes at depth = " + depth) + val leftPredicate = new SplitPredicate(split, true) + val rightPredicate = new SplitPredicate(split, false) + val leftData = data.filter(sample => sample._2(leftPredicate.split.feature) <= leftPredicate.split.threshold).cache + val rightData = data.filter(sample => sample._2(rightPredicate.split.feature) > rightPredicate.split.threshold).cache + val leftNode = if (leftData.count != 0) createNode(leftData, depth + 1, splitPredicates ::: List(leftPredicate), leftNodeStats) else new LeafNode(data) + val rightNode = if (rightData.count != 0) createNode(rightData, depth + 1, splitPredicates ::: List(rightPredicate), rightNodeStats) else new LeafNode(data) + (leftNode, rightNode, Some(leftPredicate), false) + } else { + println("not creating more child nodes since gain is not greater than 0") + (new LeafNode(data), new LeafNode(data), None, true) + } + } + } + + def findBestSplit(nodeStats: NodeStats): (Split, Double, NodeStats, NodeStats) = { + //TODO: Also remove splits that are subsets of previous splits + val availableSplits = allSplits.value filterNot (split => splits contains split) + println("availableSplit count " + availableSplits.size) + //availableSplits.map(split1 => (split1, impurity.calculateGain(split1, data))).reduce(comparePair(_, _)) + + strategy match { + case Strategy("Classification") => { + + val splitWiseCalculations = data.flatMap(sample => { + val label = sample._1 + val features = sample._2 + val leftOrRight = for { + split <- availableSplits.toSeq + featureIndex = split.feature + threshold = split.threshold + } yield { if (features(featureIndex) <= threshold) (split, "left", label) else (split, "right", label) } + leftOrRight + }).map(k => (k, 1)) + + val gainCalculations = splitWiseCalculations.countByKey() + .toMap //TODO: Hack to go from mutable to immutable map. Clean this up if needed. 
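+        //gainCalculations now maps (split, "left"/"right", label) -> sample count,
+        //aggregated in a single pass over the node's data; each candidate split's
+        //gain can be computed below from these four counts without another RDD scan.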
+ + val split_gain_list = for ( + split <- availableSplits; + gain = impurity.calculateClassificationGain(split, gainCalculations) + ) yield (split, gain) + + val split_gain = split_gain_list.reduce(comparePair(_, _)) + (split_gain._1, split_gain._2, new NodeStats, new NodeStats) + + } + case Strategy("Regression") => { + + val splitWiseCalculations = data.flatMap(sample => { + val label = sample._1 + val features = sample._2 + val leftOrRight = for { + split <- availableSplits.toSeq + featureIndex = split.feature + threshold = split.threshold + } yield {if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label)} + leftOrRight + }) + + // Calculate variance for each split + val splitVariancePairs = splitWiseCalculations.groupByKey().map(x => x._1 -> ParVariance.calculateVarianceSize(x._2)).collect + //Tuple array to map conversion + val gainCalculations = scala.collection.immutable.Map(splitVariancePairs: _*) + + val split_gain_list = for ( + split <- availableSplits; + (gain, leftNodeStats, rightNodeStats) = impurity.calculateRegressionGain(split, gainCalculations, nodeStats) + ) yield (split, gain, leftNodeStats, rightNodeStats) + + val split_gain = split_gain_list.reduce(compareRegressionPair(_, _)) + (split_gain._1, split_gain._2,split_gain._3, split_gain._4) + } + } + } + + } + + + def comparePair(x: (Split, Double), y: (Split, Double)): (Split, Double) = { + if (x._2 > y._2) x else y + } + + def compareRegressionPair(x: (Split, Double, NodeStats, NodeStats), y: (Split, Double, NodeStats, NodeStats)): (Split, Double, NodeStats, NodeStats) = { + if (x._2 > y._2) x else y + } + + + def buildTree(): Node = { + strategy match { + case Strategy("Classification") => new TopClassificationNode() + case Strategy("Regression") => { + val count = input.count + //TODO: calculate mean and variance together + val variance = input.map(x => x._1).variance + val mean = input.map(x => x._1).mean + val nodeStats = new NodeStats(count = Some(count), variance = Some(variance), mean = Some(mean)) + new TopRegressionNode(nodeStats) + } + } + } + +} + +object ParVariance extends Serializable { + + def calculateVarianceSize(seq: Seq[Double]): (Double, Double, Long) = { + val stat = StatCounter(seq) + (stat.mean, stat.variance, stat.count) + } + +} + + +object DecisionTree { + def train( + input: RDD[(Double, Array[Double])], + numSplitPredicates: Int, + strategy: Strategy, + impurity: Impurity, + maxDepth : Int, + fraction : Double, + sparkContext : SparkContext): Option[NodeModel] = { + new DecisionTree( + input = input, + numSplitPredicates = numSplitPredicates, + strategy = strategy, + impurity = impurity, + maxDepth = maxDepth, + fraction = fraction, + sparkContext = sparkContext) + .buildTree + .extractModel + } +} + +case class Strategy(val name: String) + +class NodeStats( + val gini: Option[Double] = None, + val entropy: Option[Double] = None, + val mean: Option[Double] = None, + val variance: Option[Double] = None, + val count: Option[Long] = None) extends Serializable{ + override def toString = "variance = " + variance + "count = " + count + "mean = " + mean +} + + +trait Impurity { + + def calculateClassificationGain(split: Split, calculations : Map[(Split, String, Double),Long]): Double = { + val leftRddZeroCount = calculations.getOrElse((split,"left",0.0),0L).toDouble; + val rightRddZeroCount = calculations.getOrElse((split,"right",0.0),0L).toDouble; + val leftRddOneCount = calculations.getOrElse((split,"left",1.0),0L).toDouble; + val rightRddOneCount = 
calculations.getOrElse((split,"right",1.0),0L).toDouble; + val leftRddCount = leftRddZeroCount + leftRddOneCount; + val rightRddCount = rightRddZeroCount + rightRddOneCount; + val totalZeroCount = leftRddZeroCount + rightRddZeroCount; + val totalOneCount = leftRddOneCount + rightRddOneCount; + val totalCount = totalZeroCount + totalOneCount; + val gain = { + if (leftRddCount == 0 || rightRddCount == 0) 0 + else { + val topGini = calculate(totalZeroCount,totalOneCount) + val leftWeight = leftRddCount / totalCount + val leftGini = calculate(leftRddZeroCount,leftRddOneCount) * leftWeight + val rightWeight = rightRddCount / totalCount + val rightGini = calculate(rightRddZeroCount,rightRddOneCount) * rightWeight + topGini - leftGini - rightGini + } + } + gain + } + + def calculateRegressionGain(split: Split, calculations : Map[(Split, String),(Double, Double, Long)], nodeStats : NodeStats): (Double, NodeStats, NodeStats) = { + val topCount = nodeStats.count.get + val leftCount = calculations.getOrElse((split,"left"),(0,0,0L))._3 + val rightCount = calculations.getOrElse((split,"right"),(0,0,0L))._3 + if (leftCount == 0 || rightCount == 0){ + // No gain return values + //println("leftCount = " + leftCount + "rightCount = " + rightCount + " topCount = " + topCount) + (0, new NodeStats, new NodeStats) + } else{ + val topVariance = nodeStats.variance.get + val leftMean = calculations((split,"left"))._1 + val leftVariance = calculations((split,"left"))._2 + val rightMean = calculations((split,"right"))._1 + val rightVariance = calculations((split,"right"))._2 + //TODO: Check and if needed improve these toDouble conversions + val gain = topVariance - ((leftCount.toDouble / topCount) * leftVariance) - ((rightCount.toDouble/topCount) * rightVariance) + (gain, + new NodeStats(mean = Some(leftMean), variance = Some(leftVariance), count = Some(leftCount)), + new NodeStats(mean = Some(rightMean), variance = Some(rightVariance), count = Some(rightCount))) + } + } + + def calculate(c0 : Double, c1 : Double): Double + +} + + +object Gini extends Impurity { + + def calculate(c0 : Double, c1 : Double): Double = { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0*f0 - f1*f1 + } + +} + +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + +} + +object Variance extends Impurity { + def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") +} + +object RegressionTreeRunner { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkPi []") + System.exit(1) + } + /**START Experimental*/ + System.setProperty("spark.cores.max", "8") + /**END Experimental*/ + val sc = new SparkContext(args(0), "Decision Tree Runner", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val data = TreeUtils.loadLabeledData(sc, args(1)) + val tree = DecisionTree.train( + input = data, + numSplitPredicates = 1000, + strategy = new Strategy("Regression"), + impurity = Variance, + maxDepth = 1, + fraction = 1, + sparkContext = sc) + println(tree) + println(tree.get.isLeaf) + println("prediction = " + tree.get.predict(Array(1.0, 2.0))) + } +} + +object ClassificationGiniTreeRunner { + def main(args: Array[String]) { + if (args.length == 0) { + 
System.err.println("Usage: SparkPi []") + System.exit(1) + } + /**START Experimental*/ + System.setProperty("spark.cores.max", "8") + /**END Experimental*/ + val sc = new SparkContext(args(0), "Decision Tree Runner", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val data = TreeUtils.loadLabeledData(sc, args(1)) + val tree = DecisionTree.train( + input = data, + numSplitPredicates = 1000, + strategy = new Strategy("Classification"), + impurity = Gini, + maxDepth = 1, + fraction = 1, + sparkContext = sc) + println(tree) + println(tree.get.isLeaf) + println("prediction = " + tree.get.predict(Array(1.0, 2.0))) + } +} + +object ClassificationEntropyTreeRunner { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkPi []") + System.exit(1) + } + /**START Experimental*/ + System.setProperty("spark.cores.max", "8") + /**END Experimental*/ + val sc = new SparkContext(args(0), "Decision Tree Runner", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val data = TreeUtils.loadLabeledData(sc, args(1)) + val tree = DecisionTree.train( + input = data, + numSplitPredicates = 1000, + strategy = new Strategy("Classification"), + impurity = Entropy, + maxDepth = 1, + fraction = 1, + sparkContext = sc) + println(tree) + println(tree.get.isLeaf) + println("prediction = " + tree.get.predict(Array(1.0, 2.0))) + } +} + +/** + * Helper methods to load and save data + * Data format: + * , ... + * where , are feature values in Double and is the corresponding label as Double. + */ +object TreeUtils { + + /** + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of tuples. For each tuple, the first element is the label, and the second + * element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + //val features = parts(1).trim().split(",").map(_.toDouble) + //val features = parts.slice(1,parts.length).map(_.toDouble) + val features = parts.slice(1, 3).map(_.toDouble) + (label, features) + } + } + + def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) { + val dataStr = data.map(x => x._1 + "," + x._2.mkString(" ")) + dataStr.saveAsTextFile(dir) + } + +} From e2231adcf8a5fdd419831213365be00f633c8810 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 28 Sep 2013 00:05:48 -0700 Subject: [PATCH 02/19] added design file --- src/main/scala/ml/tree/design.md | 83 ++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 src/main/scala/ml/tree/design.md diff --git a/src/main/scala/ml/tree/design.md b/src/main/scala/ml/tree/design.md new file mode 100644 index 0000000..17825ca --- /dev/null +++ b/src/main/scala/ml/tree/design.md @@ -0,0 +1,83 @@ +#Tree design doc +Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discuss the design for its implementation in the Spark project. + +**The current design will be optimized for the scenario where all the data can be fit into the in-cluster memory.** + +##Algorithm +Decision tree classifier is formed by creating recursive binary partitions using the optimal splitting criterion that maximizes the information gain at each step. 
It handles both ordered (numeric) and unordered (categorical) features.
+
+###Identifying Split Predicates
+The split predicates will be calculated by performing a single pass over the data at the start of the tree model building. The binning of the data can be performed using two techniques:
+
+1. Sorting the ordered features and finding the exact quantile points. Complexity: O(N*logN) * #features
+2. Using an [approximate quantile calculation algorithm](http://infolab.stanford.edu/~manku/papers/99sigmod-unknown.pdf) cited by the PLANET paper.
+
+###Optimal Splitting Criterion
+The splitting criterion is calculated using one of two popular criteria:
+
+1. [Gini impurity](http://en.wikipedia.org/wiki/Gini_coefficient)
+2. [Entropy](http://en.wikipedia.org/wiki/Information_gain_in_decision_trees)
+
+Each split is stored in a model for future predictions.
+
+###Stopping criterion
+There are various criteria that can be used to stop adding more levels to the tree. The first implementation will be kept simple and will use the following criteria: no further information gain can be achieved or the maximum depth has been reached. Once a stopping criterion is met, the current node is a leaf of the tree and updates the model with the distribution of the remaining classes at the node.
+
+###Prediction
+To make a prediction, a new sample is run through the decision tree model until it arrives at a leaf node. Upon reaching the leaf node, a prediction is made using the distribution of the underlying samples (typically, the distribution itself is the output).
+
+##Implementation
+
+###Code
+For a consistent API, the training code will be consistent with the existing logistic regression algorithms for supervised learning.
+
+The train() method will be of the following format:
+
+    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {...}
+    def predict(testData: spark.RDD[Array[Double]]) = {...}
+
+All splitting criteria can be evaluated in parallel using the *map* operation. The *reduce* operation will select the best splitting criterion. The split criterion will create a *filter* that should be applied to the RDD at each node to derive the RDDs at the next node.
+
+The pseudocode is given below:
+
+    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {
+      filterList = new List()
+      root = new Node()
+      buildTree(root,input,filterList)
+    }
+
+    def buildTree(node : Node, input : RDD[(Double, Array[Double])], filterList : List) : Tree = {
+      splits = find_possible_splits(input)
+      bestSplit = splits.map( split => calculateInformationGain(input, split)).reduce(_ max _)
+      if (bestSplit > threshold){
+        leftRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit))
+        rightRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit.invert()))
+        node.split = bestSplit
+        node.left = new Node()
+        node.right = new Node()
+        lefttree = buildTree(node.left,leftRDD,filterList.add(bestSplit))
+        righttree = buildTree(node.right,rightRDD,filterList.add(bestSplit.invert()))
+      }
+      node
+    }
+
+###Testing
+
+#####Unit testing
+As a standard programming practice, unit tests will be written to test the important building blocks.
+
+####Comparison with other libraries
+There are several machine learning libraries in other languages. The scikit-learn library will be used as a benchmark for functional tests.
+
+###Constraints
++ Two class labels -- The first implementation will support only binary labels.
++ Class weights -- Class weighting option (useful for highly unbalanced data) will not be supported
++ Sanity checks -- The input data sanity checks will not be performed. Ideally, a separate pre-processing step (that is common to all ML algorithms) should handle this.
+
+## Future Work
++ Weights to handle unbalanced classes
++ Ensemble methods -- random forests, boosting, etc.
+
+##References
+1. Hastie, Tibshirani, Friedman. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Springer 2009.
+2. Biswanath, Herbach, Basu and Roberto. PLANET: Massively Parallel Learning of Tree Ensembles with MapReduce, VLDB 2009.
\ No newline at end of file

From 95d45ab1c083f5563a3aeb50f0b89b42055b0ccd Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Mon, 30 Sep 2013 00:07:00 -0700
Subject: [PATCH 03/19] added accuracy score calculation

---
 src/main/scala/ml/tree/DecisionTree.scala | 140 +++++++++++-----------
 1 file changed, 72 insertions(+), 68 deletions(-)

diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
index 6c51795..0cba862 100644
--- a/src/main/scala/ml/tree/DecisionTree.scala
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.classification.ClassificationModel
 import org.apache.spark.SparkContext
 import org.apache.spark.util.StatCounter
+import org.apache.spark.Logging
 
 
 /*
@@ -106,7 +107,7 @@ class NodeModel(
 /*
  * Class used to store the prediction values at each node of the tree.
  */
-class Prediction(val prob: Double, val distribution: Map[Double, Double]) {
+class Prediction(val prob: Double, val distribution: Map[Double, Double]) extends Serializable {
   override def toString = { "probability = " + prob + ", distribution = " + distribution }
 }
 
@@ -127,14 +128,14 @@ class SplitPredicate(val split: Split, lessThanEqualTo: Boolean = true) extends
 /*
  * Class for building the Decision Tree model. Should be used for both classification and regression trees.
  */
-class DecisionTree(
+class DecisionTree (
   val input: RDD[(Double, Array[Double])], //input RDD
   val maxDepth: Int, // depth of the tree
   val numSplitPredicates: Int, // number of bins per feature
   val fraction: Double, // fraction of the data to be used for performing quantile calculation
   val strategy: Strategy, // classification or regression
-  val impurity: Impurity,
-  val sparkContext : SparkContext) { // impurity calculation strategy (variance, gini, entropy, etc.)
+  val impurity: Impurity, // impurity calculation strategy (variance, gini, entropy, etc.)
+ val sparkContext : SparkContext) { //Calculating length of the features val featureLength = input.first._2.length @@ -502,84 +503,88 @@ object Variance extends Impurity { def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") } -object RegressionTreeRunner { +object TreeRunner extends Logging { + val usage = """ + Usage: DecisionTreeRunner [slices] --strategy --dataDirectory directory [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] + """ + def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi []") - System.exit(1) - } - /**START Experimental*/ - System.setProperty("spark.cores.max", "8") - /**END Experimental*/ - val sc = new SparkContext(args(0), "Decision Tree Runner", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val data = TreeUtils.loadLabeledData(sc, args(1)) - val tree = DecisionTree.train( - input = data, - numSplitPredicates = 1000, - strategy = new Strategy("Regression"), - impurity = Variance, - maxDepth = 1, - fraction = 1, - sparkContext = sc) - println(tree) - println(tree.get.isLeaf) - println("prediction = " + tree.get.predict(Array(1.0, 2.0))) - } -} -object ClassificationGiniTreeRunner { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi []") - System.exit(1) - } + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + /**START Experimental*/ System.setProperty("spark.cores.max", "8") /**END Experimental*/ val sc = new SparkContext(args(0), "Decision Tree Runner", System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val data = TreeUtils.loadLabeledData(sc, args(1)) - val tree = DecisionTree.train( - input = data, - numSplitPredicates = 1000, - strategy = new Strategy("Classification"), - impurity = Gini, - maxDepth = 1, - fraction = 1, - sparkContext = sc) - println(tree) - println(tree.get.isLeaf) - println("prediction = " + tree.get.predict(Array(1.0, 2.0))) - } -} -object ClassificationEntropyTreeRunner { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi []") - System.exit(1) + + val arglist = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]) : OptionMap = { + def isSwitch(s : String) = (s(0) == '-') + list match { + case Nil => map + case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail) + case "--dataDirectory" :: string :: tail => nextOption(map ++ Map('dataDirectory -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => println("Unknown option "+option) + exit(1) + } } - /**START Experimental*/ - System.setProperty("spark.cores.max", "8") - /**END Experimental*/ - val sc = new SparkContext(args(0), "Decision Tree Runner", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val data = TreeUtils.loadLabeledData(sc, args(1)) + val options = nextOption(Map(),arglist) + println(options) + //TODO: Add check for acceptable string inputs + + val data = TreeUtils.loadLabeledData(sc, 
options.get('dataDirectory).get.toString) + val strategyStr = options.get('strategy).get.toString + val impurityStr = options.getOrElse('impurity,"Gini").toString + val impurity = { + impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance + } + } + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt + val fraction = options.getOrElse('samplingFractionForSplitCalculation,"1.0").toString.toDouble + val tree = DecisionTree.train( input = data, numSplitPredicates = 1000, - strategy = new Strategy("Classification"), - impurity = Entropy, - maxDepth = 1, - fraction = 1, + strategy = new Strategy(strategyStr), + impurity = impurity, + maxDepth = maxDepth, + fraction = fraction, sparkContext = sc) println(tree) - println(tree.get.isLeaf) - println("prediction = " + tree.get.predict(Array(1.0, 2.0))) + //println("prediction = " + tree.get.predict(Array(1.0, 2.0))) + + val trainingError = accuracyScore(tree, data) + print("accuracy score on training data = " + trainingError) + } + + def accuracyScore(tree : Option[ml.tree.NodeModel], data : RDD[(Double, Array[Double])]) : Double = { + if (tree.isEmpty) return 1 + val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count() + val count = data.count() + print("correct count = " + correctCount) + print("training data count = " + count) + correctCount.toDouble / count + } + + } + /** * Helper methods to load and save data * Data format: @@ -598,9 +603,8 @@ object TreeUtils { sc.textFile(dir).map { line => val parts = line.trim().split(",") val label = parts(0).toDouble - //val features = parts(1).trim().split(",").map(_.toDouble) - //val features = parts.slice(1,parts.length).map(_.toDouble) - val features = parts.slice(1, 3).map(_.toDouble) + val features = parts.slice(1,parts.length).map(_.toDouble) + //val features = parts.slice(1, 30).map(_.toDouble) (label, features) } } From 6754a40d6c8b58522831abf1142e2149c18bbf36 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 30 Sep 2013 23:19:54 -0700 Subject: [PATCH 04/19] added mean square error and test directory --- src/main/scala/ml/tree/DecisionTree.scala | 32 ++++++++++++++----- .../scala/ml/tree/{design.md => README.md} | 0 2 files changed, 24 insertions(+), 8 deletions(-) rename src/main/scala/ml/tree/{design.md => README.md} (100%) diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala index 0cba862..c4db136 100644 --- a/src/main/scala/ml/tree/DecisionTree.scala +++ b/src/main/scala/ml/tree/DecisionTree.scala @@ -505,7 +505,7 @@ object Variance extends Impurity { object TreeRunner extends Logging { val usage = """ - Usage: DecisionTreeRunner [slices] --strategy --dataDirectory directory [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] + Usage: DecisionTreeRunner [slices] --strategy --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] """ def main(args: Array[String]) { @@ -530,7 +530,8 @@ object TreeRunner extends Logging { list match { case Nil => map case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail) - case "--dataDirectory" :: string :: tail => nextOption(map ++ Map('dataDirectory -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) case "--impurity" :: string :: tail => nextOption(map ++ 
Map('impurity -> string), tail) case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail) @@ -543,7 +544,7 @@ object TreeRunner extends Logging { println(options) //TODO: Add check for acceptable string inputs - val data = TreeUtils.loadLabeledData(sc, options.get('dataDirectory).get.toString) + val trainData = TreeUtils.loadLabeledData(sc, options.get('trainDataDir).get.toString) val strategyStr = options.get('strategy).get.toString val impurityStr = options.getOrElse('impurity,"Gini").toString val impurity = { @@ -557,7 +558,7 @@ object TreeRunner extends Logging { val fraction = options.getOrElse('samplingFractionForSplitCalculation,"1.0").toString.toDouble val tree = DecisionTree.train( - input = data, + input = trainData, numSplitPredicates = 1000, strategy = new Strategy(strategyStr), impurity = impurity, @@ -567,20 +568,35 @@ object TreeRunner extends Logging { println(tree) //println("prediction = " + tree.get.predict(Array(1.0, 2.0))) - val trainingError = accuracyScore(tree, data) - print("accuracy score on training data = " + trainingError) + val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString) + + + val testError = { + strategyStr match { + case "Classification" => accuracyScore(tree, testData) + case "Regression" => meanSquaredError(tree, testData) + } + } + print("error = " + testError) } def accuracyScore(tree : Option[ml.tree.NodeModel], data : RDD[(Double, Array[Double])]) : Double = { - if (tree.isEmpty) return 1 + if (tree.isEmpty) return 1 //TODO: Throw exception val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count() val count = data.count() print("correct count = " + correctCount) print("training data count = " + count) correctCount.toDouble / count } - + + def meanSquaredError(tree : Option[ml.tree.NodeModel], data : RDD[(Double, Array[Double])]) : Double = { + if (tree.isEmpty) return 1 //TODO: Throw exception + val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean + print("meanSumOfSquares = " + meanSumOfSquares) + meanSumOfSquares + } + } diff --git a/src/main/scala/ml/tree/design.md b/src/main/scala/ml/tree/README.md similarity index 100% rename from src/main/scala/ml/tree/design.md rename to src/main/scala/ml/tree/README.md From 49b87972e54894c1a7dbc23852462dcf513c8fae Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 30 Sep 2013 23:20:43 -0700 Subject: [PATCH 05/19] placeholder test file --- src/test/scala/ml/tree/DecisionTreeTest.scala | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 src/test/scala/ml/tree/DecisionTreeTest.scala diff --git a/src/test/scala/ml/tree/DecisionTreeTest.scala b/src/test/scala/ml/tree/DecisionTreeTest.scala new file mode 100644 index 0000000..d288c7a --- /dev/null +++ b/src/test/scala/ml/tree/DecisionTreeTest.scala @@ -0,0 +1,8 @@ +package ml.tree +import org.scalatest.FunSuite + +class DecisionTreeTest extends FunSuite { + test("Basic decision tree test") { + //Decision Tree test + } +} From 376e241dcf14f2e0787d818ace1bb602001cb112 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 6 Oct 2013 23:12:36 -0700 Subject: [PATCH 06/19] basic documentation --- src/main/scala/ml/tree/README.md | 86 +++----------------------------- src/main/scala/ml/tree/design.md | 83 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 79 
deletions(-)
 create mode 100644 src/main/scala/ml/tree/design.md

diff --git a/src/main/scala/ml/tree/README.md b/src/main/scala/ml/tree/README.md
index 17825ca..e5173d3 100644
--- a/src/main/scala/ml/tree/README.md
+++ b/src/main/scala/ml/tree/README.md
@@ -1,83 +1,11 @@
-#Tree design doc
-Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discuss the design for its implementation in the Spark project.
+#Tree design doc
+Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discusses its implementation in the Spark project.
 
-**The current design will be optimized for the scenario where all the data can be fit into the in-cluster memory.**
+#Usage
+DecisionTreeRunner <master> [slices] --strategy <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--samplingFractionForSplitCalculation num]
+ 
 
-##Algorithm
-Decision tree classifier is formed by creating recursive binary partitions using the optimal splitting criterion that maximizes the information gain at each step. It handles both ordered (numeric) and unordered (categorical) features.
+#Example
+sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification --trainDataDir ../train_data --testDataDir ../test_data --maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1"
 
-###Identifying Split Predicates
-The split predicates will be calculated by performing a single pass over the data at the start of the tree model building. The binning of the data can be performed using two techniques:
-
-1. Sorting the ordered features and finding the exact quantile points. Complexity: O(N*logN) * #features
-2. Using an [approximate quantile calculation algorithm](http://infolab.stanford.edu/~manku/papers/99sigmod-unknown.pdf) cited by the PLANET paper.
-
-###Optimal Splitting Criterion
-The splitting criterion is calculated using one of two popular criteria:
-
-1. [Gini impurity](http://en.wikipedia.org/wiki/Gini_coefficient)
-2. [Entropy](http://en.wikipedia.org/wiki/Information_gain_in_decision_trees)
-
-Each split is stored in a model for future predictions.
-
-###Stopping criterion
-There are various criteria that can be used to stop adding more levels to the tree. The first implementation will be kept simple and will use the following criteria: no further information gain can be achieved or the maximum depth has been reached. Once a stopping criterion is met, the current node is a leaf of the tree and updates the model with the distribution of the remaining classes at the node.
-
-###Prediction
-To make a prediction, a new sample is run through the decision tree model until it arrives at a leaf node. Upon reaching the leaf node, a prediction is made using the distribution of the underlying samples (typically, the distribution itself is the output).
-
-##Implementation
-
-###Code
-For a consistent API, the training code will be consistent with the existing logistic regression algorithms for supervised learning.
-
-The train() method will be of the following format:
-
-    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {...}
-    def predict(testData: spark.RDD[Array[Double]]) = {...}
-
-All splitting criteria can be evaluated in parallel using the *map* operation. The *reduce* operation will select the best splitting criterion.
The split criterion will create a *filter* that should be applied to the RDD at each node to derive the RDDs at the next node.
-
-The pseudocode is given below:
-
-    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {
-      filterList = new List()
-      root = new Node()
-      buildTree(root,input,filterList)
-    }
-
-    def buildTree(node : Node, input : RDD[(Double, Array[Double])], filterList : List) : Tree = {
-      splits = find_possible_splits(input)
-      bestSplit = splits.map( split => calculateInformationGain(input, split)).reduce(_ max _)
-      if (bestSplit > threshold){
-        leftRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit))
-        rightRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit.invert()))
-        node.split = bestSplit
-        node.left = new Node()
-        node.right = new Node()
-        lefttree = buildTree(node.left,leftRDD,filterList.add(bestSplit))
-        righttree = buildTree(node.right,rightRDD,filterList.add(bestSplit.invert()))
-      }
-      node
-    }
-
-###Testing
-
-#####Unit testing
-As a standard programming practice, unit tests will be written to test the important building blocks.
-
-####Comparison with other libraries
-There are several machine learning libraries in other languages. The scikit-learn library will be used as a benchmark for functional tests.
-
-###Constraints
-+ Two class labels -- The first implementation will support only binary labels.
-+ Class weights -- Class weighting option (useful for highly unbalanced data) will not be supported
-+ Sanity checks -- The input data sanity checks will not be performed. Ideally, a separate pre-processing step (that is common to all ML algorithms) should handle this.
-
-## Future Work
-+ Weights to handle unbalanced classes
-+ Ensemble methods -- random forests, boosting, etc.
-
-##References
-1. Hastie, Tibshirani, Friedman. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Springer 2009.
-2. Biswanath, Herbach, Basu and Roberto. PLANET: Massively Parallel Learning of Tree Ensembles with MapReduce, VLDB 2009.
\ No newline at end of file
+This command will create a decision tree model using the training data in the *trainDataDir* and calculate test error using the data in the *testDataDir*. The mis-classification error is calculated for a Classification *strategy* and mean squared error is calculated for the Regression *strategy*.
\ No newline at end of file
diff --git a/src/main/scala/ml/tree/design.md b/src/main/scala/ml/tree/design.md
new file mode 100644
index 0000000..7864978
--- /dev/null
+++ b/src/main/scala/ml/tree/design.md
@@ -0,0 +1,83 @@
+#Tree design doc
+Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discusses its implementation in the Spark project.
+
+**The current design will be optimized for the scenario where all the data can be fit into the in-cluster memory.**
+
+##Algorithm
+Decision tree classifier is formed by creating recursive binary partitions using the optimal splitting criterion that maximizes the information gain at each step. It handles both ordered (numeric) and unordered (categorical) features.
+
+###Identifying Split Predicates
+The split predicates will be calculated by performing a single pass over the data at the start of the tree model building. The binning of the data can be performed using two techniques (a sketch of the first technique follows the list):
+
+1. Sorting the ordered features and finding the exact quantile points. Complexity: O(N*logN) * #features
+2. Using an [approximate quantile calculation algorithm](http://infolab.stanford.edu/~manku/papers/99sigmod-unknown.pdf) cited by the PLANET paper.
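+
+As a rough illustration of the first technique (the helper below is hypothetical and not part of the implementation), the current code sorts a sampled feature column and then takes every stride-th value as a candidate threshold:
+
+    // pick every stride-th value of a sorted feature sample as a split threshold
+    def quantileThresholds(sortedSample: Array[Double], numBins: Int): Seq[Double] = {
+      val stride = math.max(sortedSample.length / numBins, 1)
+      (stride until sortedSample.length - 1 by stride).map(sortedSample(_)).distinct
+    }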
+
+###Optimal Splitting Criterion
+The splitting criterion is calculated using one of two popular criteria:
+
+1. [Gini impurity](http://en.wikipedia.org/wiki/Gini_coefficient)
+2. [Entropy](http://en.wikipedia.org/wiki/Information_gain_in_decision_trees)
+
+Each split is stored in a model for future predictions. (A worked gain computation is given at the end of this document.)
+
+###Stopping criterion
+There are various criteria that can be used to stop adding more levels to the tree. The first implementation will be kept simple and will use the following criteria: no further information gain can be achieved or the maximum depth has been reached. Once a stopping criterion is met, the current node is a leaf of the tree and updates the model with the distribution of the remaining classes at the node.
+
+###Prediction
+To make a prediction, a new sample is run through the decision tree model until it arrives at a leaf node. Upon reaching the leaf node, a prediction is made using the distribution of the underlying samples (typically, the distribution itself is the output).
+
+##Implementation
+
+###Code
+For a consistent API, the training code will be consistent with the existing logistic regression algorithms for supervised learning.
+
+The train() method will be of the following format:
+
+    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {...}
+    def predict(testData: spark.RDD[Array[Double]]) = {...}
+
+All splitting criteria can be evaluated in parallel using the *map* operation. The *reduce* operation will select the best splitting criterion. The split criterion will create a *filter* that should be applied to the RDD at each node to derive the RDDs at the next node.
+
+The pseudocode is given below:
+
+    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {
+      filterList = new List()
+      root = new Node()
+      buildTree(root,input,filterList)
+    }
+
+    def buildTree(node : Node, input : RDD[(Double, Array[Double])], filterList : List) : Tree = {
+      splits = find_possible_splits(input)
+      bestSplit = splits.map( split => calculateInformationGain(input, split)).reduce(_ max _)
+      if (bestSplit > threshold){
+        leftRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit))
+        rightRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit.invert()))
+        node.split = bestSplit
+        node.left = new Node()
+        node.right = new Node()
+        lefttree = buildTree(node.left,leftRDD,filterList.add(bestSplit))
+        righttree = buildTree(node.right,rightRDD,filterList.add(bestSplit.invert()))
+      }
+      node
+    }
+
+###Testing
+
+#####Unit testing
+As a standard programming practice, unit tests will be written to test the important building blocks.
+
+####Comparison with other libraries
+There are several machine learning libraries in other languages. The scikit-learn library will be used as a benchmark for functional tests.
+
+###Constraints
++ Two class labels -- The first implementation will support only binary labels.
++ Class weights -- Class weighting option (useful for highly unbalanced data) will not be supported
++ Sanity checks -- The input data sanity checks will not be performed. Ideally, a separate pre-processing step (that is common to all ML algorithms) should handle this.
+
+## Future Work
++ Weights to handle unbalanced classes
++ Ensemble methods -- random forests, boosting, etc.
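+
+##Appendix: worked gain computation
+The following example illustrates the gain computation behind the splitting criteria above. The numbers are made up for illustration; the gini helper mirrors the calculate method of the Gini impurity object. Consider a node holding 40 positive and 60 negative samples, and a candidate split that sends (30 positive, 10 negative) to the left and (10 positive, 50 negative) to the right:
+
+    def gini(c0: Double, c1: Double): Double = {
+      val total = c0 + c1
+      val f0 = c0 / total
+      val f1 = c1 / total
+      1 - f0 * f0 - f1 * f1
+    }
+
+    val parent = gini(60, 40)                       // 0.48
+    val left   = gini(10, 30)                       // 0.375, weighted by 40/100
+    val right  = gini(50, 10)                       // ~0.278, weighted by 60/100
+    val gain   = parent - 0.4 * left - 0.6 * right  // ~0.163
+
+A split is kept only when this gain is positive, and the candidate with the largest gain is selected in the *reduce* step.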
+
+##References
+1. Hastie, Tibshirani, Friedman. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Springer 2009.
+2. Biswanath, Herbach, Basu and Roberto. PLANET: Massively Parallel Learning of Tree Ensembles with MapReduce, VLDB 2009.
\ No newline at end of file

From 35cebe69bfa2e5893547f3c9324ddc2db080c424 Mon Sep 17 00:00:00 2001
From: manishamde
Date: Sun, 6 Oct 2013 23:20:30 -0700
Subject: [PATCH 07/19] Update README.md

Added basic documentation.
---
 src/main/scala/ml/tree/README.md | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/main/scala/ml/tree/README.md b/src/main/scala/ml/tree/README.md
index e5173d3..93fa950 100644
--- a/src/main/scala/ml/tree/README.md
+++ b/src/main/scala/ml/tree/README.md
@@ -1,4 +1,4 @@
-#Tree design doc
+#Decision Tree
 Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discusses its implementation in the Spark project.
 
 #Usage
@@ -8,4 +8,19 @@ DecisionTreeRunner [slices] --strategy --tra
 #Example
 sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification --trainDataDir ../train_data --testDataDir ../test_data --maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1"
 
-This command will create a decision tree model using the training data in the *trainDataDir* and calculate test error using the data in the *testDataDir*. The mis-classification error is calculated for a Classification *strategy* and mean squared error is calculated for the Regression *strategy*.
\ No newline at end of file
+This command will create a decision tree model using the training data in the *trainDataDir* and calculate test error using the data in the *testDataDir*. The mis-classification error is calculated for a Classification *strategy* and mean squared error is calculated for the Regression *strategy*.
+
+#Performance testing
+To be done
+
+#Improvements
+* Print to dot files
+* Unit tests
+* Change fractions to quantiles
+* Add logging
+* Move metrics to a different package
+
+#Extensions
+* Extremely randomized trees
+* Random forest
+* Boosting

From 9ad0dd5274997adc8f69a867d7ea7d7525796850 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 12 Oct 2013 19:11:15 -0700
Subject: [PATCH 08/19] adding empty lines above comments

---
 .gitignore                                |  4 ++++
 src/main/scala/ml/tree/DecisionTree.scala | 13 +++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/.gitignore b/.gitignore
index 700ce67..768bf87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ project/plugins/project/
 #Eclipse specific
 .classpath
 .project
+
+#IDEA specific
+.idea
+.idea_modules
diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
index c4db136..7f1da78 100644
--- a/src/main/scala/ml/tree/DecisionTree.scala
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -29,19 +29,26 @@ import org.apache.spark.Logging
 * Abstract Node class as a template for implementing various types of nodes in the decision tree.
 */
 abstract class Node {
+
 //Method for checking whether the class has any left/right child nodes.
def isLeaf: Boolean + //Left/Right child nodes def left: Node def right: Node + //Depth of the node from the top node def depth: Int + //RDD data as an input to the node def data: RDD[(Double, Array[Double])] + //List of split predicates applied to the base RDD thus far def splitPredicates: List[SplitPredicate] + // Split to arrive at the node def splitPredicate: Option[SplitPredicate] + //Extract model def extractModel: Option[NodeModel] = { //Add probability logic @@ -51,6 +58,8 @@ abstract class Node { Some(new NodeModel(None, None, None, depth, isLeaf, Some(prediction))) } } + + //Prediction at the node def prediction: Prediction } @@ -88,6 +97,7 @@ class NodeModel( * @return Int prediction from the trained model */ def predict(testData: Array[Double]): Double = { + //TODO: Modify this logic to handle regression val pred = prediction.get if (this.isLeaf) { @@ -255,8 +265,10 @@ class DecisionTree ( val depth: Int, val splitPredicates: List[SplitPredicate], val nodeStats : NodeStats) extends Node { + //TODO: Change empty logic val splits = splitPredicates.map(x => x.split) + //TODO: Think about the merits of doing BFS and removing the parents RDDs from memory instead of doing DFS like below. val (left, right, splitPredicate, isLeaf) = createLeftRightChild() override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]" @@ -286,6 +298,7 @@ class DecisionTree ( } def findBestSplit(nodeStats: NodeStats): (Split, Double, NodeStats, NodeStats) = { + //TODO: Also remove splits that are subsets of previous splits val availableSplits = allSplits.value filterNot (split => splits contains split) println("availableSplit count " + availableSplits.size) From fc773de3147e1ad1089349bef955dada6fcce900 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 12 Oct 2013 19:43:41 -0700 Subject: [PATCH 09/19] moved impurity classes to a different package --- src/main/scala/ml/tree/DecisionTree.scala | 80 +------------------ src/main/scala/ml/tree/impurity/Entropy.scala | 18 +++++ src/main/scala/ml/tree/impurity/Gini.scala | 12 +++ .../scala/ml/tree/impurity/Impurity.scala | 55 +++++++++++++ .../scala/ml/tree/impurity/Variance.scala | 7 ++ 5 files changed, 93 insertions(+), 79 deletions(-) create mode 100644 src/main/scala/ml/tree/impurity/Entropy.scala create mode 100644 src/main/scala/ml/tree/impurity/Gini.scala create mode 100644 src/main/scala/ml/tree/impurity/Impurity.scala create mode 100644 src/main/scala/ml/tree/impurity/Variance.scala diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala index 7f1da78..74ccac3 100644 --- a/src/main/scala/ml/tree/DecisionTree.scala +++ b/src/main/scala/ml/tree/DecisionTree.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.ClassificationModel import org.apache.spark.SparkContext import org.apache.spark.util.StatCounter import org.apache.spark.Logging +import ml.tree.impurity.{Variance, Entropy, Gini, Impurity} /* @@ -431,90 +432,11 @@ class NodeStats( } -trait Impurity { - - def calculateClassificationGain(split: Split, calculations : Map[(Split, String, Double),Long]): Double = { - val leftRddZeroCount = calculations.getOrElse((split,"left",0.0),0L).toDouble; - val rightRddZeroCount = calculations.getOrElse((split,"right",0.0),0L).toDouble; - val leftRddOneCount = calculations.getOrElse((split,"left",1.0),0L).toDouble; - val rightRddOneCount = calculations.getOrElse((split,"right",1.0),0L).toDouble; - val leftRddCount = leftRddZeroCount + 
leftRddOneCount; - val rightRddCount = rightRddZeroCount + rightRddOneCount; - val totalZeroCount = leftRddZeroCount + rightRddZeroCount; - val totalOneCount = leftRddOneCount + rightRddOneCount; - val totalCount = totalZeroCount + totalOneCount; - val gain = { - if (leftRddCount == 0 || rightRddCount == 0) 0 - else { - val topGini = calculate(totalZeroCount,totalOneCount) - val leftWeight = leftRddCount / totalCount - val leftGini = calculate(leftRddZeroCount,leftRddOneCount) * leftWeight - val rightWeight = rightRddCount / totalCount - val rightGini = calculate(rightRddZeroCount,rightRddOneCount) * rightWeight - topGini - leftGini - rightGini - } - } - gain - } - - def calculateRegressionGain(split: Split, calculations : Map[(Split, String),(Double, Double, Long)], nodeStats : NodeStats): (Double, NodeStats, NodeStats) = { - val topCount = nodeStats.count.get - val leftCount = calculations.getOrElse((split,"left"),(0,0,0L))._3 - val rightCount = calculations.getOrElse((split,"right"),(0,0,0L))._3 - if (leftCount == 0 || rightCount == 0){ - // No gain return values - //println("leftCount = " + leftCount + "rightCount = " + rightCount + " topCount = " + topCount) - (0, new NodeStats, new NodeStats) - } else{ - val topVariance = nodeStats.variance.get - val leftMean = calculations((split,"left"))._1 - val leftVariance = calculations((split,"left"))._2 - val rightMean = calculations((split,"right"))._1 - val rightVariance = calculations((split,"right"))._2 - //TODO: Check and if needed improve these toDouble conversions - val gain = topVariance - ((leftCount.toDouble / topCount) * leftVariance) - ((rightCount.toDouble/topCount) * rightVariance) - (gain, - new NodeStats(mean = Some(leftMean), variance = Some(leftVariance), count = Some(leftCount)), - new NodeStats(mean = Some(rightMean), variance = Some(rightVariance), count = Some(rightCount))) - } - } - - def calculate(c0 : Double, c1 : Double): Double - -} - - -object Gini extends Impurity { - def calculate(c0 : Double, c1 : Double): Double = { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0*f0 - f1*f1 - } -} -object Entropy extends Impurity { - def log2(x: Double) = scala.math.log(x) / scala.math.log(2) - def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - -(f0 * log2(f0)) - (f1 * log2(f1)) - } - } - -} - -object Variance extends Impurity { - def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") -} object TreeRunner extends Logging { val usage = """ diff --git a/src/main/scala/ml/tree/impurity/Entropy.scala b/src/main/scala/ml/tree/impurity/Entropy.scala new file mode 100644 index 0000000..4c0e7e0 --- /dev/null +++ b/src/main/scala/ml/tree/impurity/Entropy.scala @@ -0,0 +1,18 @@ +package ml.tree.impurity + +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + } diff --git a/src/main/scala/ml/tree/impurity/Gini.scala b/src/main/scala/ml/tree/impurity/Gini.scala new file mode 100644 index 0000000..ec349d3 --- /dev/null +++ b/src/main/scala/ml/tree/impurity/Gini.scala @@ -0,0 +1,12 @@ +package ml.tree.impurity + +object Gini extends Impurity { + + def calculate(c0 : Double, c1 : Double): Double = { + 
val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0*f0 - f1*f1 + } + + } diff --git a/src/main/scala/ml/tree/impurity/Impurity.scala b/src/main/scala/ml/tree/impurity/Impurity.scala new file mode 100644 index 0000000..550310a --- /dev/null +++ b/src/main/scala/ml/tree/impurity/Impurity.scala @@ -0,0 +1,55 @@ +package ml.tree.impurity + +import ml.tree.{NodeStats, Split} + +trait Impurity { + + def calculateClassificationGain(split: Split, calculations : Map[(Split, String, Double),Long]): Double = { + val leftRddZeroCount = calculations.getOrElse((split,"left",0.0),0L).toDouble; + val rightRddZeroCount = calculations.getOrElse((split,"right",0.0),0L).toDouble; + val leftRddOneCount = calculations.getOrElse((split,"left",1.0),0L).toDouble; + val rightRddOneCount = calculations.getOrElse((split,"right",1.0),0L).toDouble; + val leftRddCount = leftRddZeroCount + leftRddOneCount; + val rightRddCount = rightRddZeroCount + rightRddOneCount; + val totalZeroCount = leftRddZeroCount + rightRddZeroCount; + val totalOneCount = leftRddOneCount + rightRddOneCount; + val totalCount = totalZeroCount + totalOneCount; + val gain = { + if (leftRddCount == 0 || rightRddCount == 0) 0 + else { + val topGini = calculate(totalZeroCount,totalOneCount) + val leftWeight = leftRddCount / totalCount + val leftGini = calculate(leftRddZeroCount,leftRddOneCount) * leftWeight + val rightWeight = rightRddCount / totalCount + val rightGini = calculate(rightRddZeroCount,rightRddOneCount) * rightWeight + topGini - leftGini - rightGini + } + } + gain + } + + def calculateRegressionGain(split: Split, calculations : Map[(Split, String),(Double, Double, Long)], nodeStats : NodeStats): (Double, NodeStats, NodeStats) = { + val topCount = nodeStats.count.get + val leftCount = calculations.getOrElse((split,"left"),(0,0,0L))._3 + val rightCount = calculations.getOrElse((split,"right"),(0,0,0L))._3 + if (leftCount == 0 || rightCount == 0){ + // No gain return values + //println("leftCount = " + leftCount + "rightCount = " + rightCount + " topCount = " + topCount) + (0, new NodeStats, new NodeStats) + } else{ + val topVariance = nodeStats.variance.get + val leftMean = calculations((split,"left"))._1 + val leftVariance = calculations((split,"left"))._2 + val rightMean = calculations((split,"right"))._1 + val rightVariance = calculations((split,"right"))._2 + //TODO: Check and if needed improve these toDouble conversions + val gain = topVariance - ((leftCount.toDouble / topCount) * leftVariance) - ((rightCount.toDouble/topCount) * rightVariance) + (gain, + new NodeStats(mean = Some(leftMean), variance = Some(leftVariance), count = Some(leftCount)), + new NodeStats(mean = Some(rightMean), variance = Some(rightVariance), count = Some(rightCount))) + } + } + + def calculate(c0 : Double, c1 : Double): Double + +} diff --git a/src/main/scala/ml/tree/impurity/Variance.scala b/src/main/scala/ml/tree/impurity/Variance.scala new file mode 100644 index 0000000..8dda877 --- /dev/null +++ b/src/main/scala/ml/tree/impurity/Variance.scala @@ -0,0 +1,7 @@ +package ml.tree.impurity + +import javax.naming.OperationNotSupportedException + +object Variance extends Impurity { + def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") + } From e29493c86b3c6444d110aed1b85f49d173cd7101 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 12 Oct 2013 21:01:28 -0700 Subject: [PATCH 10/19] reorganized code --- src/main/scala/ml/tree/DecisionTree.scala | 291 ++---------------- 
src/main/scala/ml/tree/README.md | 2 +- src/main/scala/ml/tree/TreeRunner.scala | 106 +++++++ src/main/scala/ml/tree/TreeUtils.scala | 35 +++ .../scala/ml/tree/impurity/Impurity.scala | 3 +- src/main/scala/ml/tree/node/Node.scala | 42 +++ src/main/scala/ml/tree/node/NodeModel.scala | 56 ++++ src/main/scala/ml/tree/node/NodeStats.scala | 10 + src/main/scala/ml/tree/node/Prediction.scala | 8 + src/main/scala/ml/tree/split/Split.scala | 8 + .../scala/ml/tree/split/SplitPredicate.scala | 8 + .../scala/ml/tree/strategy/Strategy.scala | 3 + 12 files changed, 300 insertions(+), 272 deletions(-) create mode 100644 src/main/scala/ml/tree/TreeRunner.scala create mode 100644 src/main/scala/ml/tree/TreeUtils.scala create mode 100644 src/main/scala/ml/tree/node/Node.scala create mode 100644 src/main/scala/ml/tree/node/NodeModel.scala create mode 100644 src/main/scala/ml/tree/node/NodeStats.scala create mode 100644 src/main/scala/ml/tree/node/Prediction.scala create mode 100644 src/main/scala/ml/tree/split/Split.scala create mode 100644 src/main/scala/ml/tree/split/SplitPredicate.scala create mode 100644 src/main/scala/ml/tree/strategy/Strategy.scala diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala index 74ccac3..aa08bd0 100644 --- a/src/main/scala/ml/tree/DecisionTree.scala +++ b/src/main/scala/ml/tree/DecisionTree.scala @@ -24,118 +24,12 @@ import org.apache.spark.SparkContext import org.apache.spark.util.StatCounter import org.apache.spark.Logging import ml.tree.impurity.{Variance, Entropy, Gini, Impurity} +import ml.tree.node.{Prediction, NodeStats, NodeModel, Node} +import ml.tree.strategy.Strategy +import ml.tree.split.{SplitPredicate, Split} +import org.apache.spark.broadcast.Broadcast -/* - * Abstract Node class as a template for implementing various types of nodes in the decision tree. - */ -abstract class Node { - - //Method for checking whether the class has any left/right child nodes. - def isLeaf: Boolean - - //Left/Right child nodes - def left: Node - def right: Node - - //Depth of the node from the top node - def depth: Int - - //RDD data as an input to the node - def data: RDD[(Double, Array[Double])] - - //List of split predicates applied to the base RDD thus far - def splitPredicates: List[SplitPredicate] - - // Split to arrive at the node - def splitPredicate: Option[SplitPredicate] - - //Extract model - def extractModel: Option[NodeModel] = { - //Add probability logic - if (!splitPredicate.isEmpty) { Some(new NodeModel(splitPredicate, left.extractModel, right.extractModel, depth, isLeaf, Some(prediction))) } - else { - // Using -1 as depth - Some(new NodeModel(None, None, None, depth, isLeaf, Some(prediction))) - } - } - - //Prediction at the node - def prediction: Prediction -} - -/** - * The decision tree model class that - */ -class NodeModel( - val splitPredicate: Option[SplitPredicate], - val trueNode: Option[NodeModel], - val falseNode: Option[NodeModel], - val depth: Int, - val isLeaf: Boolean, - val prediction: Option[Prediction]) extends ClassificationModel { - - override def toString() = if (!splitPredicate.isEmpty) { - "[" + trueNode.get + "\n" + "[" + "depth = " + depth + ", split predicate = " + this.splitPredicate.get + ", predict = " + this.prediction + "]" + "]\n" + falseNode.get - } else { - "Leaf : " + "depth = " + depth + ", predict = " + prediction + ", isLeaf = " + isLeaf - } - - /** - * Predict values for the given data set using the model trained. 
- * - * @param testData RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction - */ - def predict(testData: RDD[Array[Double]]): RDD[Double] = { - testData.map { x => predict(x) } - } - - /** - * Predict values for a single data point using the model trained. - * - * @param testData array representing a single data point - * @return Int prediction from the trained model - */ - def predict(testData: Array[Double]): Double = { - - //TODO: Modify this logic to handle regression - val pred = prediction.get - if (this.isLeaf) { - if (pred.prob > 0.5) 1 else 0 - } else { - val spPred = splitPredicate.get - if (testData(spPred.split.feature) <= spPred.split.threshold) { - trueNode.get.predict(testData) - } else { - falseNode.get.predict(testData) - } - } - } - -} - -/* - * Class used to store the prediction values at each node of the tree. - */ -class Prediction(val prob: Double, val distribution: Map[Double, Double]) extends Serializable { - override def toString = { "probability = " + prob + ", distribution = " + distribution } -} - -/* - * Class for storing splits -- feature index and threshold - */ -case class Split(val feature: Int, val threshold: Double) { - override def toString = "feature = " + feature + ", threshold = " + threshold -} - -/* - * Class for storing the split predicate. - */ -class SplitPredicate(val split: Split, lessThanEqualTo: Boolean = true) extends Serializable { - override def toString = "split = " + split.toString + ", lessThan = " + lessThanEqualTo -} - /* * Class for building the Decision Tree model. Should be used for both classification and regression tree. */ @@ -262,14 +156,14 @@ class DecisionTree ( } abstract class DecisionNode( - val data: RDD[(Double, Array[Double])], - val depth: Int, + val data: RDD[(Double, Array[Double])], + val depth: Int, val splitPredicates: List[SplitPredicate], val nodeStats : NodeStats) extends Node { - + //TODO: Change empty logic val splits = splitPredicates.map(x => x.split) - + //TODO: Think about the merits of doing BFS and removing the parents RDDs from memory instead of doing DFS like below. 
val (left, right, splitPredicate, isLeaf) = createLeftRightChild() override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]" @@ -299,7 +193,7 @@ class DecisionTree ( } def findBestSplit(nodeStats: NodeStats): (Split, Double, NodeStats, NodeStats) = { - + //TODO: Also remove splits that are subsets of previous splits val availableSplits = allSplits.value filterNot (split => splits contains split) println("availableSplit count " + availableSplits.size) @@ -307,7 +201,7 @@ class DecisionTree ( strategy match { case Strategy("Classification") => { - + val splitWiseCalculations = data.flatMap(sample => { val label = sample._1 val features = sample._2 @@ -332,7 +226,7 @@ class DecisionTree ( } case Strategy("Regression") => { - + val splitWiseCalculations = data.flatMap(sample => { val label = sample._1 val features = sample._2 @@ -343,23 +237,29 @@ class DecisionTree ( } yield {if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label)} leftOrRight }) - + // Calculate variance for each split - val splitVariancePairs = splitWiseCalculations.groupByKey().map(x => x._1 -> ParVariance.calculateVarianceSize(x._2)).collect + val splitVariancePairs = splitWiseCalculations.groupByKey().map(x => x._1 -> calculateVarianceSize(x._2)).collect //Tuple array to map conversion val gainCalculations = scala.collection.immutable.Map(splitVariancePairs: _*) - + val split_gain_list = for ( split <- availableSplits; (gain, leftNodeStats, rightNodeStats) = impurity.calculateRegressionGain(split, gainCalculations, nodeStats) ) yield (split, gain, leftNodeStats, rightNodeStats) - + val split_gain = split_gain_list.reduce(compareRegressionPair(_, _)) (split_gain._1, split_gain._2,split_gain._3, split_gain._4) } } } + def calculateVarianceSize(seq: Seq[Double]): (Double, Double, Long) = { + val stat = StatCounter(seq) + (stat.mean, stat.variance, stat.count) + } + + } @@ -388,15 +288,6 @@ class DecisionTree ( } -object ParVariance extends Serializable { - - def calculateVarianceSize(seq: Seq[Double]): (Double, Double, Long) = { - val stat = StatCounter(seq) - (stat.mean, stat.variance, stat.count) - } - -} - object DecisionTree { def train( @@ -420,149 +311,9 @@ object DecisionTree { } } -case class Strategy(val name: String) - -class NodeStats( - val gini: Option[Double] = None, - val entropy: Option[Double] = None, - val mean: Option[Double] = None, - val variance: Option[Double] = None, - val count: Option[Long] = None) extends Serializable{ - override def toString = "variance = " + variance + "count = " + count + "mean = " + mean -} - - - - - - - - -object TreeRunner extends Logging { - val usage = """ - Usage: DecisionTreeRunner [slices] --strategy --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] - """ - - def main(args: Array[String]) { - - if (args.length < 2) { - System.err.println(usage) - System.exit(1) - } - - /**START Experimental*/ - System.setProperty("spark.cores.max", "8") - /**END Experimental*/ - val sc = new SparkContext(args(0), "Decision Tree Runner", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - - val arglist = args.toList.drop(1) - type OptionMap = Map[Symbol, Any] - - def nextOption(map : OptionMap, list: List[String]) : OptionMap = { - def isSwitch(s : String) = (s(0) == '-') - list match { - case Nil => map - case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail) - 
case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) - case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) - case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) - case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) - case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail) - case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) - case option :: tail => println("Unknown option "+option) - exit(1) - } - } - val options = nextOption(Map(),arglist) - println(options) - //TODO: Add check for acceptable string inputs - - val trainData = TreeUtils.loadLabeledData(sc, options.get('trainDataDir).get.toString) - val strategyStr = options.get('strategy).get.toString - val impurityStr = options.getOrElse('impurity,"Gini").toString - val impurity = { - impurityStr match { - case "Gini" => Gini - case "Entropy" => Entropy - case "Variance" => Variance - } - } - val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt - val fraction = options.getOrElse('samplingFractionForSplitCalculation,"1.0").toString.toDouble - - val tree = DecisionTree.train( - input = trainData, - numSplitPredicates = 1000, - strategy = new Strategy(strategyStr), - impurity = impurity, - maxDepth = maxDepth, - fraction = fraction, - sparkContext = sc) - println(tree) - //println("prediction = " + tree.get.predict(Array(1.0, 2.0))) - - val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString) - - - val testError = { - strategyStr match { - case "Classification" => accuracyScore(tree, testData) - case "Regression" => meanSquaredError(tree, testData) - } - } - print("error = " + testError) - - } - - def accuracyScore(tree : Option[ml.tree.NodeModel], data : RDD[(Double, Array[Double])]) : Double = { - if (tree.isEmpty) return 1 //TODO: Throw exception - val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count() - val count = data.count() - print("correct count = " + correctCount) - print("training data count = " + count) - correctCount.toDouble / count - } - - def meanSquaredError(tree : Option[ml.tree.NodeModel], data : RDD[(Double, Array[Double])]) : Double = { - if (tree.isEmpty) return 1 //TODO: Throw exception - val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean - print("meanSumOfSquares = " + meanSumOfSquares) - meanSumOfSquares - } - -} -/** - * Helper methods to load and save data - * Data format: - * , ... - * where , are feature values in Double and is the corresponding label as Double. - */ -object TreeUtils { - /** - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of tuples. For each tuple, the first element is the label, and the second - * element represents the feature values (an array of Double). 
- */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = { - sc.textFile(dir).map { line => - val parts = line.trim().split(",") - val label = parts(0).toDouble - val features = parts.slice(1,parts.length).map(_.toDouble) - //val features = parts.slice(1, 30).map(_.toDouble) - (label, features) - } - } - def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) { - val dataStr = data.map(x => x._1 + "," + x._2.mkString(" ")) - dataStr.saveAsTextFile(dir) - } -} diff --git a/src/main/scala/ml/tree/README.md b/src/main/scala/ml/tree/README.md index 93fa950..8607742 100644 --- a/src/main/scala/ml/tree/README.md +++ b/src/main/scala/ml/tree/README.md @@ -2,7 +2,7 @@ Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discusses its implementation in the Spark project. #Usage -DecisionTreeRunner [slices] --strategy --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] +TreeRunner [slices] --strategy --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] #Example diff --git a/src/main/scala/ml/tree/TreeRunner.scala b/src/main/scala/ml/tree/TreeRunner.scala new file mode 100644 index 0000000..a36b954 --- /dev/null +++ b/src/main/scala/ml/tree/TreeRunner.scala @@ -0,0 +1,106 @@ +package ml.tree + +import org.apache.spark.SparkContext._ +import org.apache.spark.{Logging, SparkContext} +import ml.tree.impurity.{Variance, Entropy, Gini} +import ml.tree.strategy.Strategy + +import ml.tree.node.NodeModel +import org.apache.spark.rdd.RDD + +object TreeRunner extends Logging { + val usage = """ + Usage: TreeRunner [slices] --strategy --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--samplingFractionForSplitCalculation num] + """ + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + /**START Experimental*/ + System.setProperty("spark.cores.max", "8") + /**END Experimental*/ + val sc = new SparkContext(args(0), "Decision Tree Runner", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + + val arglist = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]) : OptionMap = { + def isSwitch(s : String) = (s(0) == '-') + list match { + case Nil => map + case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => println("Unknown option "+option) + exit(1) + } + } + val options = nextOption(Map(),arglist) + println(options) + //TODO: Add check for acceptable string inputs + + val trainData = TreeUtils.loadLabeledData(sc, options.get('trainDataDir).get.toString) + val strategyStr = 
options.get('strategy).get.toString
+    val impurityStr = options.getOrElse('impurity,"Gini").toString
+    val impurity = {
+      impurityStr match {
+        case "Gini" => Gini
+        case "Entropy" => Entropy
+        case "Variance" => Variance
+      }
+    }
+    val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt
+    val fraction = options.getOrElse('samplingFractionForSplitCalculation,"1.0").toString.toDouble
+
+    val tree = DecisionTree.train(
+      input = trainData,
+      numSplitPredicates = 1000,
+      strategy = new Strategy(strategyStr),
+      impurity = impurity,
+      maxDepth = maxDepth,
+      fraction = fraction,
+      sparkContext = sc)
+    println(tree)
+    //println("prediction = " + tree.get.predict(Array(1.0, 2.0)))
+
+    val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString)
+
+
+    val testError = {
+      strategyStr match {
+        case "Classification" => accuracyScore(tree, testData)
+        case "Regression" => meanSquaredError(tree, testData)
+      }
+    }
+    print("error = " + testError)
+
+  }
+
+  def accuracyScore(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
+    if (tree.isEmpty) return 1 //TODO: Throw exception
+    val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count()
+    val count = data.count()
+    print("correct count = " + correctCount)
+    print("training data count = " + count)
+    correctCount.toDouble / count
+  }
+
+  def meanSquaredError(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
+    if (tree.isEmpty) return 1 //TODO: Throw exception
+    val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean()
+    print("meanSumOfSquares = " + meanSumOfSquares)
+    meanSumOfSquares
+  }
+
+
+}
diff --git a/src/main/scala/ml/tree/TreeUtils.scala b/src/main/scala/ml/tree/TreeUtils.scala
new file mode 100644
index 0000000..2cacdbe
--- /dev/null
+++ b/src/main/scala/ml/tree/TreeUtils.scala
@@ -0,0 +1,35 @@
+package ml.tree
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+
+/**
+ * Helper methods to load and save data
+ * Data format:
+ * <label>, <f1> <f2> ...
+ * where <f1>, <f2> are feature values in Double and <label> is the corresponding label as Double.
+ */
+object TreeUtils {
+
+  /**
+   * @param sc SparkContext
+   * @param dir Directory to the input data files.
+   * @return An RDD of tuples. For each tuple, the first element is the label, and the second
+   *         element represents the feature values (an array of Double).
+ */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + //val features = parts.slice(1, 30).map(_.toDouble) + (label, features) + } + } + + def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) { + val dataStr = data.map(x => x._1 + "," + x._2.mkString(" ")) + dataStr.saveAsTextFile(dir) + } + +} diff --git a/src/main/scala/ml/tree/impurity/Impurity.scala b/src/main/scala/ml/tree/impurity/Impurity.scala index 550310a..8b9095b 100644 --- a/src/main/scala/ml/tree/impurity/Impurity.scala +++ b/src/main/scala/ml/tree/impurity/Impurity.scala @@ -1,6 +1,7 @@ package ml.tree.impurity -import ml.tree.{NodeStats, Split} +import ml.tree.node.NodeStats +import ml.tree.split.Split trait Impurity { diff --git a/src/main/scala/ml/tree/node/Node.scala b/src/main/scala/ml/tree/node/Node.scala new file mode 100644 index 0000000..b99b216 --- /dev/null +++ b/src/main/scala/ml/tree/node/Node.scala @@ -0,0 +1,42 @@ +package ml.tree.node + +import org.apache.spark.rdd.RDD +import ml.tree.split.SplitPredicate + +/* + * Node trait as a template for implementing various types of nodes in the decision tree. + */ +trait Node { + + //Method for checking whether the class has any left/right child nodes. + def isLeaf: Boolean + + //Left/Right child nodes + def left: Node + def right: Node + + //Depth of the node from the top node + def depth: Int + + //RDD data as an input to the node + def data: RDD[(Double, Array[Double])] + + //List of split predicates applied to the base RDD thus far + def splitPredicates: List[SplitPredicate] + + // Split to arrive at the node + def splitPredicate: Option[SplitPredicate] + + //Extract model + def extractModel: Option[NodeModel] = { + //Add probability logic + if (!splitPredicate.isEmpty) { Some(new NodeModel(splitPredicate, left.extractModel, right.extractModel, depth, isLeaf, Some(prediction))) } + else { + // Using -1 as depth + Some(new NodeModel(None, None, None, depth, isLeaf, Some(prediction))) + } + } + + //Prediction at the node + def prediction: Prediction +} diff --git a/src/main/scala/ml/tree/node/NodeModel.scala b/src/main/scala/ml/tree/node/NodeModel.scala new file mode 100644 index 0000000..0f41000 --- /dev/null +++ b/src/main/scala/ml/tree/node/NodeModel.scala @@ -0,0 +1,56 @@ +package ml.tree.node + +import org.apache.spark.mllib.classification.ClassificationModel +import org.apache.spark.rdd.RDD +import ml.tree.split.SplitPredicate + +/** + * The decision tree model class that + */ +class NodeModel( + val splitPredicate: Option[SplitPredicate], + val trueNode: Option[NodeModel], + val falseNode: Option[NodeModel], + val depth: Int, + val isLeaf: Boolean, + val prediction: Option[Prediction]) extends ClassificationModel { + + override def toString() = if (!splitPredicate.isEmpty) { + "[" + trueNode.get + "\n" + "[" + "depth = " + depth + ", split predicate = " + this.splitPredicate.get + ", predict = " + this.prediction + "]" + "]\n" + falseNode.get + } else { + "Leaf : " + "depth = " + depth + ", predict = " + prediction + ", isLeaf = " + isLeaf + } + + /** + * Predict values for the given data set using the model trained. 
+ *
+ * @param testData RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+  def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+    testData.map { x => predict(x) }
+  }
+
+  /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param testData array representing a single data point
+ * @return Double prediction from the trained model
+ */
+  def predict(testData: Array[Double]): Double = {
+
+    //TODO: Modify this logic to handle regression
+    val pred = prediction.get
+    if (this.isLeaf) {
+      if (pred.prob > 0.5) 1 else 0
+    } else {
+      val spPred = splitPredicate.get
+      if (testData(spPred.split.feature) <= spPred.split.threshold) {
+        trueNode.get.predict(testData)
+      } else {
+        falseNode.get.predict(testData)
+      }
+    }
+  }
+
+}
diff --git a/src/main/scala/ml/tree/node/NodeStats.scala b/src/main/scala/ml/tree/node/NodeStats.scala
new file mode 100644
index 0000000..7387e90
--- /dev/null
+++ b/src/main/scala/ml/tree/node/NodeStats.scala
@@ -0,0 +1,10 @@
+package ml.tree.node
+
+class NodeStats(
+    val gini: Option[Double] = None,
+    val entropy: Option[Double] = None,
+    val mean: Option[Double] = None,
+    val variance: Option[Double] = None,
+    val count: Option[Long] = None) extends Serializable {
+  override def toString = "variance = " + variance + ", count = " + count + ", mean = " + mean
+}
diff --git a/src/main/scala/ml/tree/node/Prediction.scala b/src/main/scala/ml/tree/node/Prediction.scala
new file mode 100644
index 0000000..7417e20
--- /dev/null
+++ b/src/main/scala/ml/tree/node/Prediction.scala
@@ -0,0 +1,8 @@
+package ml.tree.node
+
+/*
+ * Class used to store the prediction values at each node of the tree.
+ */
+class Prediction(val prob: Double, val distribution: Map[Double, Double]) extends Serializable {
+  override def toString = { "probability = " + prob + ", distribution = " + distribution }
+}
diff --git a/src/main/scala/ml/tree/split/Split.scala b/src/main/scala/ml/tree/split/Split.scala
new file mode 100644
index 0000000..69db788
--- /dev/null
+++ b/src/main/scala/ml/tree/split/Split.scala
@@ -0,0 +1,8 @@
+package ml.tree.split
+
+/*
+ * Class for storing splits -- feature index and threshold
+ */
+case class Split(val feature: Int, val threshold: Double) {
+  override def toString = "feature = " + feature + ", threshold = " + threshold
+}
diff --git a/src/main/scala/ml/tree/split/SplitPredicate.scala b/src/main/scala/ml/tree/split/SplitPredicate.scala
new file mode 100644
index 0000000..63a1701
--- /dev/null
+++ b/src/main/scala/ml/tree/split/SplitPredicate.scala
@@ -0,0 +1,8 @@
+package ml.tree.split
+
+/*
+ * Class for storing the split predicate.
+ */
+class SplitPredicate(val split: Split, lessThanEqualTo: Boolean = true) extends Serializable {
+  override def toString = "split = " + split.toString + ", lessThan = " + lessThanEqualTo
+}
diff --git a/src/main/scala/ml/tree/strategy/Strategy.scala b/src/main/scala/ml/tree/strategy/Strategy.scala
new file mode 100644
index 0000000..9a0e144
--- /dev/null
+++ b/src/main/scala/ml/tree/strategy/Strategy.scala
@@ -0,0 +1,3 @@
+package ml.tree.strategy
+
+case class Strategy(val name: String)

From 6047ed88c57b88ea764d74e84ffc60f7a1fadff3 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 12 Oct 2013 23:20:12 -0700
Subject: [PATCH 11/19] refactored decision nodes into separate package

---
 src/main/scala/ml/tree/DecisionTree.scala      | 192 +----------------
 .../scala/ml/tree/node/decisionNodes.scala     | 201 ++++++++++++++++++
 2 files changed, 208 insertions(+), 185 deletions(-)
 create mode 100644 src/main/scala/ml/tree/node/decisionNodes.scala

diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
index aa08bd0..d012f7a 100644
--- a/src/main/scala/ml/tree/DecisionTree.scala
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -24,10 +24,13 @@ import org.apache.spark.SparkContext
 import org.apache.spark.util.StatCounter
 import org.apache.spark.Logging
 import ml.tree.impurity.{Variance, Entropy, Gini, Impurity}
-import ml.tree.node.{Prediction, NodeStats, NodeModel, Node}
 import ml.tree.strategy.Strategy
 import ml.tree.split.{SplitPredicate, Split}
 import org.apache.spark.broadcast.Broadcast
+import scala.Some
+import ml.tree.strategy.Strategy
+import ml.tree.split.Split
+import ml.tree.node._
 
 
 /*
@@ -90,198 +93,16 @@ class DecisionTree (
     new Split(featureIndex, valueAtRDDIndex(featureIndex, index))
   }
 
-  /*
-   * Empty Node class used to terminate leaf nodes
-   */
-  private class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node {
-    def isLeaf = true
-    def left = throw new OperationNotSupportedException("EmptyNode.left")
-    def right = throw new OperationNotSupportedException("EmptyNode.right")
-    def depth = throw new OperationNotSupportedException("EmptyNode.depth")
-    def splitPredicates = throw new OperationNotSupportedException("EmptyNode.splitPredicates")
-    def splitPredicate = throw new OperationNotSupportedException("EmptyNode.splitPredicate")
-    override def toString() = "Empty"
-    val prediction: Prediction = {
-      val countZero: Double = data.filter(x => (x._1 == 0.0)).count
-      val countOne: Double = data.filter(x => (x._1 == 1.0)).count
-      val countTotal: Double = countZero + countOne
-      new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne))
-    }
-  }
-
-  /*
-   * Top node for building a classification tree
-   */
-  private class TopClassificationNode extends ClassificationNode(input.cache, 0, List[SplitPredicate](), new NodeStats) {
-    override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
-  }
-
-  /*
-   * Class for each node in the classification tree
-   */
-  private class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats)
-    extends DecisionNode(data, depth, splitPredicates, nodeStats) {
-
-    // Prediction at each classification node
-    val prediction: Prediction = {
-      val countZero: Double = data.filter(x => (x._1 == 0.0)).count
-      val countOne: Double = data.filter(x => (x._1 == 1.0)).count
-      val countTotal: Double = countZero + countOne
-      new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne))
-    }
-
-    //Static
factory method. Put it in a better location. - def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) = new ClassificationNode(anyData, depth, splitPredicates, nodeStats) - - } - - /* - * Top node for building a regression tree - */ - private class TopRegressionNode(nodeStats : NodeStats) extends RegressionNode(input.cache, 0, List[SplitPredicate](), nodeStats) { - override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" - } - - /* - * Class for each node in the regression tree - */ - private class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) - extends DecisionNode(data, depth, splitPredicates, nodeStats) { - - // Prediction at each regression node - val prediction: Prediction = new Prediction(data.map(_._1).mean, Map()) - - //Static factory method. Put it in a better location. - def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) = new RegressionNode(anyData, depth, splitPredicates, nodeStats) - } - - abstract class DecisionNode( - val data: RDD[(Double, Array[Double])], - val depth: Int, - val splitPredicates: List[SplitPredicate], - val nodeStats : NodeStats) extends Node { - - //TODO: Change empty logic - val splits = splitPredicates.map(x => x.split) - - //TODO: Think about the merits of doing BFS and removing the parents RDDs from memory instead of doing DFS like below. - val (left, right, splitPredicate, isLeaf) = createLeftRightChild() - override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]" - def createNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats): DecisionNode - def createLeftRightChild(): (Node, Node, Option[SplitPredicate], Boolean) = { - if (depth > maxDepth) { - (new LeafNode(data), new LeafNode(data), None, true) - } else { - println("split count " + splits.length) - val split_gain = findBestSplit(nodeStats) - val (split, gain, leftNodeStats, rightNodeStats) = split_gain - println("Selected split = " + split + " with gain = " + gain, "left node stats = " + leftNodeStats + " right node stats = " + rightNodeStats) - if (split_gain._2 > 0) { - println("creating new nodes at depth = " + depth) - val leftPredicate = new SplitPredicate(split, true) - val rightPredicate = new SplitPredicate(split, false) - val leftData = data.filter(sample => sample._2(leftPredicate.split.feature) <= leftPredicate.split.threshold).cache - val rightData = data.filter(sample => sample._2(rightPredicate.split.feature) > rightPredicate.split.threshold).cache - val leftNode = if (leftData.count != 0) createNode(leftData, depth + 1, splitPredicates ::: List(leftPredicate), leftNodeStats) else new LeafNode(data) - val rightNode = if (rightData.count != 0) createNode(rightData, depth + 1, splitPredicates ::: List(rightPredicate), rightNodeStats) else new LeafNode(data) - (leftNode, rightNode, Some(leftPredicate), false) - } else { - println("not creating more child nodes since gain is not greater than 0") - (new LeafNode(data), new LeafNode(data), None, true) - } - } - } - - def findBestSplit(nodeStats: NodeStats): (Split, Double, NodeStats, NodeStats) = { - - //TODO: Also remove splits that are subsets of previous splits - val availableSplits = allSplits.value filterNot (split => splits contains split) - println("availableSplit 
count " + availableSplits.size) - //availableSplits.map(split1 => (split1, impurity.calculateGain(split1, data))).reduce(comparePair(_, _)) - - strategy match { - case Strategy("Classification") => { - - val splitWiseCalculations = data.flatMap(sample => { - val label = sample._1 - val features = sample._2 - val leftOrRight = for { - split <- availableSplits.toSeq - featureIndex = split.feature - threshold = split.threshold - } yield { if (features(featureIndex) <= threshold) (split, "left", label) else (split, "right", label) } - leftOrRight - }).map(k => (k, 1)) - - val gainCalculations = splitWiseCalculations.countByKey() - .toMap //TODO: Hack to go from mutable to immutable map. Clean this up if needed. - - val split_gain_list = for ( - split <- availableSplits; - gain = impurity.calculateClassificationGain(split, gainCalculations) - ) yield (split, gain) - - val split_gain = split_gain_list.reduce(comparePair(_, _)) - (split_gain._1, split_gain._2, new NodeStats, new NodeStats) - - } - case Strategy("Regression") => { - - val splitWiseCalculations = data.flatMap(sample => { - val label = sample._1 - val features = sample._2 - val leftOrRight = for { - split <- availableSplits.toSeq - featureIndex = split.feature - threshold = split.threshold - } yield {if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label)} - leftOrRight - }) - - // Calculate variance for each split - val splitVariancePairs = splitWiseCalculations.groupByKey().map(x => x._1 -> calculateVarianceSize(x._2)).collect - //Tuple array to map conversion - val gainCalculations = scala.collection.immutable.Map(splitVariancePairs: _*) - - val split_gain_list = for ( - split <- availableSplits; - (gain, leftNodeStats, rightNodeStats) = impurity.calculateRegressionGain(split, gainCalculations, nodeStats) - ) yield (split, gain, leftNodeStats, rightNodeStats) - - val split_gain = split_gain_list.reduce(compareRegressionPair(_, _)) - (split_gain._1, split_gain._2,split_gain._3, split_gain._4) - } - } - } - - def calculateVarianceSize(seq: Seq[Double]): (Double, Double, Long) = { - val stat = StatCounter(seq) - (stat.mean, stat.variance, stat.count) - } - - - } - - - def comparePair(x: (Split, Double), y: (Split, Double)): (Split, Double) = { - if (x._2 > y._2) x else y - } - - def compareRegressionPair(x: (Split, Double, NodeStats, NodeStats), y: (Split, Double, NodeStats, NodeStats)): (Split, Double, NodeStats, NodeStats) = { - if (x._2 > y._2) x else y - } - - def buildTree(): Node = { strategy match { - case Strategy("Classification") => new TopClassificationNode() + case Strategy("Classification") => new TopClassificationNode(input, allSplits, impurity, strategy, maxDepth) case Strategy("Regression") => { val count = input.count //TODO: calculate mean and variance together val variance = input.map(x => x._1).variance val mean = input.map(x => x._1).mean val nodeStats = new NodeStats(count = Some(count), variance = Some(variance), mean = Some(mean)) - new TopRegressionNode(nodeStats) + new TopRegressionNode(input, nodeStats,allSplits, impurity, strategy, maxDepth) } } } @@ -317,3 +138,4 @@ object DecisionTree { + diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala new file mode 100644 index 0000000..7c45150 --- /dev/null +++ b/src/main/scala/ml/tree/node/decisionNodes.scala @@ -0,0 +1,201 @@ +package ml.tree.node + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import ml.tree.split.{Split, 
SplitPredicate} +import org.apache.spark.broadcast.Broadcast +import ml.tree.impurity.Impurity +import ml.tree.strategy.Strategy +import org.apache.spark.util.StatCounter +import javax.naming.OperationNotSupportedException + +abstract class DecisionNode( + val data: RDD[(Double, Array[Double])], + val depth: Int, + val splitPredicates: List[SplitPredicate], + val nodeStats : NodeStats, + val allSplits : Broadcast[Set[Split]], + val impurity : Impurity, + val strategy: Strategy, + val maxDepth : Int) extends Node { + + //TODO: Change empty logic + val splits = splitPredicates.map(x => x.split) + + //TODO: Think about the merits of doing BFS and removing the parents RDDs from memory instead of doing DFS like below. + val (left, right, splitPredicate, isLeaf) = createLeftRightChild() + override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]" + def createNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats): DecisionNode + def createLeftRightChild(): (Node, Node, Option[SplitPredicate], Boolean) = { + if (depth > maxDepth) { + (new LeafNode(data), new LeafNode(data), None, true) + } else { + println("split count " + splits.length) + val split_gain = findBestSplit(nodeStats) + val (split, gain, leftNodeStats, rightNodeStats) = split_gain + println("Selected split = " + split + " with gain = " + gain, "left node stats = " + leftNodeStats + " right node stats = " + rightNodeStats) + if (split_gain._2 > 0) { + println("creating new nodes at depth = " + depth) + val leftPredicate = new SplitPredicate(split, true) + val rightPredicate = new SplitPredicate(split, false) + val leftData = data.filter(sample => sample._2(leftPredicate.split.feature) <= leftPredicate.split.threshold).cache + val rightData = data.filter(sample => sample._2(rightPredicate.split.feature) > rightPredicate.split.threshold).cache + val leftNode = if (leftData.count != 0) createNode(leftData, depth + 1, splitPredicates ::: List(leftPredicate), leftNodeStats) else new LeafNode(data) + val rightNode = if (rightData.count != 0) createNode(rightData, depth + 1, splitPredicates ::: List(rightPredicate), rightNodeStats) else new LeafNode(data) + (leftNode, rightNode, Some(leftPredicate), false) + } else { + println("not creating more child nodes since gain is not greater than 0") + (new LeafNode(data), new LeafNode(data), None, true) + } + } + } + + def comparePair(x: (Split, Double), y: (Split, Double)): (Split, Double) = { + if (x._2 > y._2) x else y + } + + def compareRegressionPair(x: (Split, Double, NodeStats, NodeStats), y: (Split, Double, NodeStats, NodeStats)): (Split, Double, NodeStats, NodeStats) = { + if (x._2 > y._2) x else y + } + + + def findBestSplit(nodeStats: NodeStats): (Split, Double, NodeStats, NodeStats) = { + + //TODO: Also remove splits that are subsets of previous splits + val availableSplits = allSplits.value filterNot (split => splits contains split) + println("availableSplit count " + availableSplits.size) + //availableSplits.map(split1 => (split1, impurity.calculateGain(split1, data))).reduce(comparePair(_, _)) + + strategy match { + case Strategy("Classification") => { + + val splitWiseCalculations = data.flatMap(sample => { + val label = sample._1 + val features = sample._2 + val leftOrRight = for { + split <- availableSplits.toSeq + featureIndex = split.feature + threshold = split.threshold + } yield { if (features(featureIndex) <= threshold) (split, "left", label) else (split, 
"right", label) } + leftOrRight + }).map(k => (k, 1)) + + val gainCalculations = splitWiseCalculations.countByKey() + .toMap //TODO: Hack to go from mutable to immutable map. Clean this up if needed. + + val split_gain_list = for ( + split <- availableSplits; + gain = impurity.calculateClassificationGain(split, gainCalculations) + ) yield (split, gain) + + val split_gain = split_gain_list.reduce(comparePair(_, _)) + (split_gain._1, split_gain._2, new NodeStats, new NodeStats) + + } + case Strategy("Regression") => { + + val splitWiseCalculations = data.flatMap(sample => { + val label = sample._1 + val features = sample._2 + val leftOrRight = for { + split <- availableSplits.toSeq + featureIndex = split.feature + threshold = split.threshold + } yield {if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label)} + leftOrRight + }) + + // Calculate variance for each split + val splitVariancePairs = splitWiseCalculations.groupByKey().map(x => x._1 -> calculateVarianceSize(x._2)).collect + //Tuple array to map conversion + val gainCalculations = scala.collection.immutable.Map(splitVariancePairs: _*) + + val split_gain_list = for ( + split <- availableSplits; + (gain, leftNodeStats, rightNodeStats) = impurity.calculateRegressionGain(split, gainCalculations, nodeStats) + ) yield (split, gain, leftNodeStats, rightNodeStats) + + val split_gain = split_gain_list.reduce(compareRegressionPair(_, _)) + (split_gain._1, split_gain._2,split_gain._3, split_gain._4) + } + } + } + + def calculateVarianceSize(seq: Seq[Double]): (Double, Double, Long) = { + val stat = StatCounter(seq) + (stat.mean, stat.variance, stat.count) + } + + +} + + +/* + * Top node for building a classification tree + */ +class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits : Broadcast[Set[Split]], impurity : Impurity,strategy: Strategy, maxDepth : Int) extends ClassificationNode(input.cache, 0, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) { + override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" +} + +/* + * Class for each node in the classification tree + */ +class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats, allSplits : Broadcast[Set[Split]], impurity : Impurity, strategy: Strategy, maxDepth : Int) + extends DecisionNode(data, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) { + + // Prediction at each classification node + val prediction: Prediction = { + val countZero: Double = data.filter(x => (x._1 == 0.0)).count + val countOne: Double = data.filter(x => (x._1 == 1.0)).count + val countTotal: Double = countZero + countOne + new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne)) + } + + //Static factory method. Put it in a better location. 
+ def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) + = new ClassificationNode(anyData, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) + +} + +/* + * Top node for building a regression tree + */ +class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats : NodeStats, allSplits : Broadcast[Set[Split]], impurity : Impurity,strategy: Strategy, maxDepth : Int) extends RegressionNode(input.cache, 0, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) { + override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" +} + +/* + * Class for each node in the regression tree + */ +class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats, allSplits : Broadcast[Set[Split]], impurity : Impurity,strategy: Strategy, maxDepth : Int) + extends DecisionNode(data, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) { + + // Prediction at each regression node + val prediction: Prediction = new Prediction(data.map(_._1).mean, Map()) + + //Static factory method. Put it in a better location. + def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) + = new RegressionNode(anyData, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) +} + +/* + * Empty Node class used to terminate leaf nodes + */ +class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node { + def isLeaf = true + def left = throw new OperationNotSupportedException("EmptyNode.left") + def right = throw new OperationNotSupportedException("EmptyNode.right") + def depth = throw new OperationNotSupportedException("EmptyNode.depth") + def splitPredicates = throw new OperationNotSupportedException("EmptyNode.splitPredicates") + def splitPredicate = throw new OperationNotSupportedException("EmptyNode.splitPredicate") + override def toString() = "Empty" + val prediction: Prediction = { + val countZero: Double = data.filter(x => (x._1 == 0.0)).count + val countOne: Double = data.filter(x => (x._1 == 1.0)).count + val countTotal: Double = countZero + countOne + new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne)) + } +} + + + From 2a1185be448ac1de1623e5482b37a504343fcbd3 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 12 Oct 2013 23:53:54 -0700 Subject: [PATCH 12/19] making variance serializable --- src/main/scala/ml/tree/TreeUtils.scala | 2 + .../scala/ml/tree/node/decisionNodes.scala | 47 +++++++++++++------ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/main/scala/ml/tree/TreeUtils.scala b/src/main/scala/ml/tree/TreeUtils.scala index 2cacdbe..29abc4e 100644 --- a/src/main/scala/ml/tree/TreeUtils.scala +++ b/src/main/scala/ml/tree/TreeUtils.scala @@ -3,6 +3,8 @@ package ml.tree import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD + +//TODO: Deprecate this when we find something equivalent in ml utils /** * Helper methods to load and save data * Data format: diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala index 7c45150..3ba767e 100644 --- a/src/main/scala/ml/tree/node/decisionNodes.scala +++ b/src/main/scala/ml/tree/node/decisionNodes.scala @@ -13,19 +13,22 @@ abstract class DecisionNode( val data: RDD[(Double, Array[Double])], val depth: Int, val splitPredicates: 
List[SplitPredicate], - val nodeStats : NodeStats, - val allSplits : Broadcast[Set[Split]], - val impurity : Impurity, + val nodeStats: NodeStats, + val allSplits: Broadcast[Set[Split]], + val impurity: Impurity, val strategy: Strategy, - val maxDepth : Int) extends Node { + val maxDepth: Int) extends Node { //TODO: Change empty logic val splits = splitPredicates.map(x => x.split) //TODO: Think about the merits of doing BFS and removing the parents RDDs from memory instead of doing DFS like below. val (left, right, splitPredicate, isLeaf) = createLeftRightChild() + override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]" - def createNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats): DecisionNode + + def createNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats): DecisionNode + def createLeftRightChild(): (Node, Node, Option[SplitPredicate], Boolean) = { if (depth > maxDepth) { (new LeafNode(data), new LeafNode(data), None, true) @@ -76,7 +79,9 @@ abstract class DecisionNode( split <- availableSplits.toSeq featureIndex = split.feature threshold = split.threshold - } yield { if (features(featureIndex) <= threshold) (split, "left", label) else (split, "right", label) } + } yield { + if (features(featureIndex) <= threshold) (split, "left", label) else (split, "right", label) + } leftOrRight }).map(k => (k, 1)) @@ -101,12 +106,17 @@ abstract class DecisionNode( split <- availableSplits.toSeq featureIndex = split.feature threshold = split.threshold - } yield {if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label)} + } yield { + if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label) + } leftOrRight }) // Calculate variance for each split - val splitVariancePairs = splitWiseCalculations.groupByKey().map(x => x._1 -> calculateVarianceSize(x._2)).collect + val splitVariancePairs = splitWiseCalculations + .groupByKey() + .map(x => x._1 -> {val stat = StatCounter(x._2); (stat.mean, stat.variance, stat.count)}) + .collect //Tuple array to map conversion val gainCalculations = scala.collection.immutable.Map(splitVariancePairs: _*) @@ -116,7 +126,7 @@ abstract class DecisionNode( ) yield (split, gain, leftNodeStats, rightNodeStats) val split_gain = split_gain_list.reduce(compareRegressionPair(_, _)) - (split_gain._1, split_gain._2,split_gain._3, split_gain._4) + (split_gain._1, split_gain._2, split_gain._3, split_gain._4) } } } @@ -133,14 +143,14 @@ abstract class DecisionNode( /* * Top node for building a classification tree */ -class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits : Broadcast[Set[Split]], impurity : Impurity,strategy: Strategy, maxDepth : Int) extends ClassificationNode(input.cache, 0, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) { +class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends ClassificationNode(input.cache, 0, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) { override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" } /* * Class for each node in the classification tree */ -class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : 
NodeStats, allSplits : Broadcast[Set[Split]], impurity : Impurity, strategy: Strategy, maxDepth : Int) +class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends DecisionNode(data, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) { // Prediction at each classification node @@ -152,7 +162,7 @@ class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPr } //Static factory method. Put it in a better location. - def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats) + def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats) = new ClassificationNode(anyData, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) } @@ -160,21 +170,21 @@ class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPr /* * Top node for building a regression tree */ -class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats : NodeStats, allSplits : Broadcast[Set[Split]], impurity : Impurity,strategy: Strategy, maxDepth : Int) extends RegressionNode(input.cache, 0, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) { +class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends RegressionNode(input.cache, 0, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) { override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]" } /* * Class for each node in the regression tree */ -class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats, allSplits : Broadcast[Set[Split]], impurity : Impurity,strategy: Strategy, maxDepth : Int) +class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends DecisionNode(data, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) { // Prediction at each regression node val prediction: Prediction = new Prediction(data.map(_._1).mean, Map()) //Static factory method. Put it in a better location. 
-  def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats : NodeStats)
+  def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats)
     = new RegressionNode(anyData, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth)
 }

@@ -183,12 +193,19 @@ class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredic
  */
 class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node {
   def isLeaf = true
+  def left = throw new OperationNotSupportedException("EmptyNode.left")
+  def right = throw new OperationNotSupportedException("EmptyNode.right")
+  def depth = throw new OperationNotSupportedException("EmptyNode.depth")
+  def splitPredicates = throw new OperationNotSupportedException("EmptyNode.splitPredicates")
+  def splitPredicate = throw new OperationNotSupportedException("EmptyNode.splitPredicate")
+  override def toString() = "Empty"
+
   val prediction: Prediction = {
     val countZero: Double = data.filter(x => (x._1 == 0.0)).count
     val countOne: Double = data.filter(x => (x._1 == 1.0)).count

From b2447a8f1673c1420283b28fd38e1e8a58b00cd3 Mon Sep 17 00:00:00 2001
From: manishamde
Date: Sun, 13 Oct 2013 00:21:40 -0700
Subject: [PATCH 13/19] fixing usage and example

---
 src/main/scala/ml/tree/README.md | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/main/scala/ml/tree/README.md b/src/main/scala/ml/tree/README.md
index 8607742..56a6f9f 100644
--- a/src/main/scala/ml/tree/README.md
+++ b/src/main/scala/ml/tree/README.md
@@ -2,11 +2,23 @@
 Decision tree classifiers are popular supervised learning algorithms as well as building blocks for ensemble learning algorithms such as random forests and boosting. This document discusses their implementation in the Spark project.

 #Usage
-TreeRunner <master> [slices] --strategy <Classification|Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <impurity>] [--samplingFractionForSplitCalculation num]
-
+```
+ml.tree.TreeRunner <master>
+[slices]
+--strategy <Classification|Regression>
+--trainDataDir path
+--testDataDir path
+[--maxDepth num]
+[--impurity <impurity>]
+[--samplingFractionForSplitCalculation num]
+```
 #Example
-sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification --trainDataDir ../train_data --testDataDir ../test_data --maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1"
+```
+sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification
+--trainDataDir ../train_data --testDataDir ../test_data
+--maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1"
+```

 This command will create a decision tree model using the training data in the *trainDataDir* and calculate the test error using the data in the *testDataDir*. The misclassification error is calculated for the Classification *strategy* and the mean squared error for the Regression *strategy*.
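For comparison, the same run can be wired up programmatically. The following is a minimal sketch, assuming the `DecisionTree.train` signature and the `Metrics.accuracyScore` helper that appear later in this patch series, the `TreeUtils.loadLabeledData` loader that `TreeRunner` itself uses, and a hypothetical `Gini` object implementing `Impurity`:

```scala
import org.apache.spark.SparkContext
import ml.tree.DecisionTree
import ml.tree.strategy.Strategy
import ml.tree.Metrics.accuracyScore

val sc = new SparkContext("local[2]", "DecisionTreeExample")

// loadLabeledData returns an RDD[(Double, Array[Double])] of (label, features) pairs
val trainData = TreeUtils.loadLabeledData(sc, "../train_data")
val testData = TreeUtils.loadLabeledData(sc, "../test_data")

val model = DecisionTree.train(
  input = trainData,
  numSplitPredicates = 100,              // hypothetical number of candidate thresholds per feature
  strategy = Strategy("Classification"),
  impurity = Gini,                       // stand-in for an ml.tree.impurity implementation
  maxDepth = 1,
  fraction = 1.0,                        // use the full data set for quantile calculation
  sparkContext = sc)

println("test accuracy = " + accuracyScore(model, testData))
```

Since `train` returns an `Option[NodeModel]`, the resulting model can be scored directly with the same `Metrics` helpers that `TreeRunner` uses.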
From 68ad6c8844f67edea5d76a119279bea823e63d4b Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 19 Oct 2013 15:40:04 -0700
Subject: [PATCH 14/19] drastic speedup of split calculation

---
 src/main/scala/ml/tree/DecisionTree.scala | 44 +++++++++++------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
index d012f7a..24d10a8 100644
--- a/src/main/scala/ml/tree/DecisionTree.scala
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -48,18 +48,21 @@ class DecisionTree (
   //Calculating length of the features
   val featureLength = input.first._2.length
   println("feature length = " + featureLength)
-
+
   //Sampling a fraction of the input RDD
   val sampledData = input.sample(false, fraction, 42).cache()
-
+  println("sampled data size for quantile calculation = " + sampledData.count)
+
   //Sorting the sampled data along each feature and storing it for quantile calculation
+  println("started sorting sampled data")
   val sortedSampledFeatures = {
-    val sortedFeatureArray = new Array[RDD[Double]](featureLength)
+    val sortedFeatureArray = new Array[Array[Double]](featureLength)
     0 until featureLength foreach {
-      i => sortedFeatureArray(i) = sampledData.map(x => x._2(i) -> None).sortByKey(true).map(_._1)
+      i => sortedFeatureArray(i) = sampledData.map(x => x._2(i) -> None).sortByKey(true).map(_._1).collect()
     }
     sortedFeatureArray
   }
+  println("finished sorting sampled data")

   val numSamples = sampledData.count
   println("num samples = " + numSamples)

@@ -68,11 +71,13 @@ class DecisionTree (
   val stride = scala.math.max(numSamples / numSplitPredicates, 1)
   println("stride = " + stride)

-  //Calculating all possible splits for the features
+  //Calculating all possible splits for the features
+  println("calculating all possible splits for features")
   val allSplitsList = for {
     featureIndex <- 0 until featureLength;
     index <- stride until numSamples - 1 by stride
   } yield createSplit(featureIndex, index)
+  println("finished calculating all possible splits for features")

   //Remove duplicate splits. Especially helpful for one-hot encoded categorical variables.
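   //For example (illustrative numbers only): numSamples = 200 and numSplitPredicates = 10
   //give stride = 20, so each feature contributes candidate thresholds at the sorted sample
   //indices 20, 40, ..., 180 -- approximately that feature's quantiles.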
   val allSplits = sparkContext.broadcast(allSplitsList.toSet)

@@ -83,7 +88,7 @@ class DecisionTree (
    * Find the exact value using feature index and index into the sorted features
    */
   def valueAtRDDIndex(featuresIndex: Long, index: Long): Double = {
-    sortedSampledFeatures(featuresIndex.toInt).collect()(index.toInt)
+    sortedSampledFeatures(featuresIndex.toInt)(index.toInt)
   }

@@ -94,6 +99,9 @@ class DecisionTree (
   }

   def buildTree(): Node = {
+
+    println("building decision tree")
+
     strategy match {
       case Strategy("Classification") => new TopClassificationNode(input, allSplits, impurity, strategy, maxDepth)
       case Strategy("Regression") => {
@@ -112,13 +120,13 @@ class DecisionTree (

 object DecisionTree {
   def train(
-    input: RDD[(Double, Array[Double])],
-    numSplitPredicates: Int,
-    strategy: Strategy,
-    impurity: Impurity,
-    maxDepth : Int,
-    fraction : Double,
-    sparkContext : SparkContext): Option[NodeModel] = {
+      input: RDD[(Double, Array[Double])],
+      numSplitPredicates: Int,
+      strategy: Strategy,
+      impurity: Impurity,
+      maxDepth : Int,
+      fraction : Double,
+      sparkContext : SparkContext): Option[NodeModel] = {
     new DecisionTree(
       input = input,
       numSplitPredicates = numSplitPredicates,
@@ -130,12 +138,4 @@ object DecisionTree {
       .buildTree
       .extractModel
   }
-}
-
-
-
-
-
-
-
-
+}
\ No newline at end of file

From 729a3b1e7533a6fa45cde28b892d92c1bb303bf2 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 19 Oct 2013 16:24:15 -0700
Subject: [PATCH 15/19] moving metrics to new class and changing root node depth to 1 from 0

---
 src/main/scala/ml/tree/Metrics.scala       | 31 +++++++++++++++++++
 src/main/scala/ml/tree/TreeRunner.scala    | 18 ++---------
 src/main/scala/ml/tree/node/Node.scala     |  8 +++--
 .../scala/ml/tree/node/decisionNodes.scala |  4 +--
 4 files changed, 40 insertions(+), 21 deletions(-)
 create mode 100644 src/main/scala/ml/tree/Metrics.scala

diff --git a/src/main/scala/ml/tree/Metrics.scala b/src/main/scala/ml/tree/Metrics.scala
new file mode 100644
index 0000000..d644bdd
--- /dev/null
+++ b/src/main/scala/ml/tree/Metrics.scala
@@ -0,0 +1,31 @@
+package ml.tree
+
+import org.apache.spark.SparkContext._
+import ml.tree.node.NodeModel
+import org.apache.spark.rdd.RDD
+
+/*
+Helper methods for measuring performance of ML algorithms
+ */
+object Metrics {
+
+  //TODO: Make these generic MLTable metrics.
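+  //accuracyScore = (correct predictions) / (total examples); e.g. 90 correct out of 100 gives 0.9.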
+  def accuracyScore(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
+    if (tree.isEmpty) return 1 //TODO: Throw exception
+    val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count()
+    val count = data.count()
+    print("correct count = " + correctCount)
+    print("training data count = " + count)
+    correctCount.toDouble / count
+  }
+
+  //TODO: Make these generic MLTable metrics
+  def meanSquaredError(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
+    if (tree.isEmpty) return 1 //TODO: Throw exception
+    val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean()
+    print("meanSumOfSquares = " + meanSumOfSquares)
+    meanSumOfSquares
+  }
+
+
+}
diff --git a/src/main/scala/ml/tree/TreeRunner.scala b/src/main/scala/ml/tree/TreeRunner.scala
index a36b954..d78fa64 100644
--- a/src/main/scala/ml/tree/TreeRunner.scala
+++ b/src/main/scala/ml/tree/TreeRunner.scala
@@ -8,6 +8,8 @@ import ml.tree.strategy.Strategy
 import ml.tree.node.NodeModel
 import org.apache.spark.rdd.RDD

+import ml.tree.Metrics.{accuracyScore,meanSquaredError}
+
 object TreeRunner extends Logging {
   val usage = """
     Usage: TreeRunner <master> [slices] --strategy <Classification|Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <impurity>] [--samplingFractionForSplitCalculation num]
@@ -86,21 +88,5 @@ object TreeRunner extends Logging {

   }

-  def accuracyScore(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
-    if (tree.isEmpty) return 1 //TODO: Throw exception
-    val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count()
-    val count = data.count()
-    print("correct count = " + correctCount)
-    print("training data count = " + count)
-    correctCount.toDouble / count
-  }
-
-  def meanSquaredError(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
-    if (tree.isEmpty) return 1 //TODO: Throw exception
-    val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean()
-    print("meanSumOfSquares = " + meanSumOfSquares)
-    meanSumOfSquares
-  }
-
 }
diff --git a/src/main/scala/ml/tree/node/Node.scala b/src/main/scala/ml/tree/node/Node.scala
index b99b216..c179822 100644
--- a/src/main/scala/ml/tree/node/Node.scala
+++ b/src/main/scala/ml/tree/node/Node.scala
@@ -2,6 +2,8 @@ package ml.tree.node

 import org.apache.spark.rdd.RDD
 import ml.tree.split.SplitPredicate
+import ml.tree.Metrics._
+import scala.Some

 /*
  * Node trait as a template for implementing various types of nodes in the decision tree.
@@ -30,13 +32,13 @@ trait Node {
   //Extract model
   def extractModel: Option[NodeModel] = {
     //Add probability logic
-    if (!splitPredicate.isEmpty) { Some(new NodeModel(splitPredicate, left.extractModel, right.extractModel, depth, isLeaf, Some(prediction))) }
+    if (!splitPredicate.isEmpty) {
+      Some(new NodeModel(splitPredicate, left.extractModel, right.extractModel, depth, isLeaf, Some(prediction)))
+    }
     else {
-      // Using -1 as depth
       Some(new NodeModel(None, None, None, depth, isLeaf, Some(prediction)))
     }
   }
-  //Prediction at the node
   def prediction: Prediction
 }
diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala
index 3ba767e..14d9cf4 100644
--- a/src/main/scala/ml/tree/node/decisionNodes.scala
+++ b/src/main/scala/ml/tree/node/decisionNodes.scala
@@ -143,7 +143,7 @@ abstract class DecisionNode(
 /*
  * Top node for building a classification tree
  */
-class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends ClassificationNode(input.cache, 0, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) {
+class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends ClassificationNode(input.cache, 1, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) {
   override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
 }

@@ -170,7 +170,7 @@ class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPr
 /*
  * Top node for building a regression tree
  */
-class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends RegressionNode(input.cache, 0, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) {
+class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends RegressionNode(input.cache, 1, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) {
   override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
 }

From d956cd33cbf4ca7e81b545ad2641533c8321a220 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 19 Oct 2013 18:33:38 -0700
Subject: [PATCH 16/19] calculating training error while building decision tree

---
 src/main/scala/ml/tree/DecisionTree.scala       | 17 ++++++++++++++++-
 src/main/scala/ml/tree/Metrics.scala            |  6 +++---
 src/main/scala/ml/tree/TreeRunner.scala         |  5 +++--
 src/main/scala/ml/tree/node/NodeModel.scala     |  1 +
 src/main/scala/ml/tree/node/decisionNodes.scala |  5 ++++-
 5 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
index 24d10a8..fc0d98a 100644
--- a/src/main/scala/ml/tree/DecisionTree.scala
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -31,6 +31,10 @@ import scala.Some
 import ml.tree.strategy.Strategy
 import ml.tree.split.Split
 import ml.tree.node._
+import ml.tree.Metrics._
+import scala.Some
+import ml.tree.strategy.Strategy
+import ml.tree.split.Split


 /*
@@ -127,7 +131,7 @@ object DecisionTree {
       maxDepth : Int,
       fraction : Double,
       sparkContext : SparkContext): Option[NodeModel] = {
-    new DecisionTree(
+    val tree = new DecisionTree(
       input = input,
       numSplitPredicates = numSplitPredicates,
       strategy = strategy,
@@ -137,5 +141,16 @@ object DecisionTree {
       sparkContext = sparkContext)
       .buildTree
       .extractModel
+
+    println("calculating performance on training data")
+    val trainingError = {
+      strategy match {
+        case Strategy("Classification") => accuracyScore(tree, input)
+        case Strategy("Regression") => meanSquaredError(tree, input)
+      }
+    }
+    println("error = " + trainingError)
+
+    tree
   }
 }
\ No newline at end of file
diff --git a/src/main/scala/ml/tree/Metrics.scala b/src/main/scala/ml/tree/Metrics.scala
index d644bdd..225b636 100644
--- a/src/main/scala/ml/tree/Metrics.scala
+++ b/src/main/scala/ml/tree/Metrics.scala
@@ -14,8 +14,8 @@ object Metrics {
     if (tree.isEmpty) return 1 //TODO: Throw exception
     val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count()
     val count = data.count()
-    print("correct count = " + correctCount)
-    print("training data count = " + count)
+    println("correct prediction count = " + correctCount)
+    println("data count = " + count)
     correctCount.toDouble / count
   }

@@ -23,7 +23,7 @@ object Metrics {
   def meanSquaredError(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
     if (tree.isEmpty) return 1 //TODO: Throw exception
     val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean()
-    print("meanSumOfSquares = " + meanSumOfSquares)
+    println("meanSumOfSquares = " + meanSumOfSquares)
     meanSumOfSquares
   }

diff --git a/src/main/scala/ml/tree/TreeRunner.scala b/src/main/scala/ml/tree/TreeRunner.scala
index d78fa64..46016af 100644
--- a/src/main/scala/ml/tree/TreeRunner.scala
+++ b/src/main/scala/ml/tree/TreeRunner.scala
@@ -75,16 +75,17 @@ object TreeRunner extends Logging {
   println(tree)
   //println("prediction = " + tree.get.predict(Array(1.0, 2.0)))

+  println("loading test data")
   val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString)
-
+  println("calculating performance of test data")
   val testError = {
     strategyStr match {
       case "Classification" => accuracyScore(tree, testData)
       case "Regression" => meanSquaredError(tree, testData)
     }
   }
-  print("error = " + testError)
+  println("error = " + testError)

 }
diff --git a/src/main/scala/ml/tree/node/NodeModel.scala b/src/main/scala/ml/tree/node/NodeModel.scala
index 0f41000..7045a6e 100644
--- a/src/main/scala/ml/tree/node/NodeModel.scala
+++ b/src/main/scala/ml/tree/node/NodeModel.scala
@@ -3,6 +3,7 @@ package ml.tree.node
 import org.apache.spark.mllib.classification.ClassificationModel
 import org.apache.spark.rdd.RDD
 import ml.tree.split.SplitPredicate
+import ml.tree.Metrics._

 /**
  * The decision tree model class that
diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala
index 14d9cf4..102f594 100644
--- a/src/main/scala/ml/tree/node/decisionNodes.scala
+++ b/src/main/scala/ml/tree/node/decisionNodes.scala
@@ -8,6 +8,10 @@ import ml.tree.impurity.Impurity
 import ml.tree.strategy.Strategy
 import org.apache.spark.util.StatCounter
 import javax.naming.OperationNotSupportedException
+import ml.tree.Metrics._
+import scala.Some
+import ml.tree.strategy.Strategy
+import ml.tree.split.Split

 abstract class DecisionNode(
   val data: RDD[(Double, Array[Double])],
@@ -136,7 +140,6 @@ abstract class DecisionNode(
       (stat.mean, stat.variance, stat.count)
     }

-
 }

From 422ed7d209073b3db8101474a0dff770ae729b29 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 19 Oct 2013 21:30:52 -0700
Subject: [PATCH 17/19] renaming error to accuracy :-)
---
 src/main/scala/ml/tree/DecisionTree.scala | 2 +-
 src/main/scala/ml/tree/TreeRunner.scala   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
index fc0d98a..7a6f1ce 100644
--- a/src/main/scala/ml/tree/DecisionTree.scala
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -149,7 +149,7 @@ object DecisionTree {
         case Strategy("Regression") => meanSquaredError(tree, input)
       }
     }
-    println("error = " + trainingError)
+    println("accuracy = " + trainingError)

     tree
   }
diff --git a/src/main/scala/ml/tree/TreeRunner.scala b/src/main/scala/ml/tree/TreeRunner.scala
index 46016af..3343180 100644
--- a/src/main/scala/ml/tree/TreeRunner.scala
+++ b/src/main/scala/ml/tree/TreeRunner.scala
@@ -85,7 +85,7 @@ object TreeRunner extends Logging {
       case "Regression" => meanSquaredError(tree, testData)
     }
   }
-  println("error = " + testError)
+  println("accuracy = " + testError)

 }

From 020069ab5dccf9a763ebee57c7532af2712ed397 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 9 Nov 2013 20:52:05 -0800
Subject: [PATCH 18/19] first attempt at vectorization

---
 .../scala/ml/tree/node/decisionNodes.scala | 67 +++++++++++++++----
 1 file changed, 53 insertions(+), 14 deletions(-)

diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala
index 102f594..ac865e7 100644
--- a/src/main/scala/ml/tree/node/decisionNodes.scala
+++ b/src/main/scala/ml/tree/node/decisionNodes.scala
@@ -12,6 +12,7 @@ import ml.tree.Metrics._
 import scala.Some
 import ml.tree.strategy.Strategy
 import ml.tree.split.Split
+import scala.collection.mutable

 abstract class DecisionNode(
   val data: RDD[(Double, Array[Double])],
@@ -76,25 +77,52 @@ abstract class DecisionNode(

     strategy match {
       case Strategy("Classification") => {
-        val splitWiseCalculations = data.flatMap(sample => {
-          val label = sample._1
-          val features = sample._2
-          val leftOrRight = for {
-            split <- availableSplits.toSeq
-            featureIndex = split.feature
-            threshold = split.threshold
-          } yield {
-            if (features(featureIndex) <= threshold) (split, "left", label) else (split, "right", label)
+        //Write a function that takes an RDD and a list of splits
+        //and returns a map of (split, left/right, label) -> count
+
+        val splits = availableSplits.toSeq
+
+        //Modify numLabels to support multiple classes in the future
+        val numLabels = 2
+        val numChildren = 2
+        val lenSplits = splits.length
+        val outputVectorLength = numLabels * numChildren * lenSplits
+        val vecToVec : RDD[Array[Long]] = data.map(
+          sample => {
+            val storage : Array[Long] = new Array[Long](outputVectorLength)
+            val label = sample._1
+            val features = sample._2
+            splits.zipWithIndex.foreach{case (split, i) =>
+              val featureIndex = split.feature
+              val threshold = split.threshold
+              if (features(featureIndex) <= threshold) { //left node
+                val index = i*(numLabels*numChildren) + label.toInt
+                storage(index) = 1
+              } else { //right node
+                val index = i*(numLabels*numChildren) + numLabels + label.toInt
+                storage(index) = 1
+              }
+            }
+            storage
           }
-          leftOrRight
-        }).map(k => (k, 1))
+        )
+
+        val countVecToVec : Array[Long] = vecToVec.reduce((a1,a2) => NodeHelper.sumTwoArrays(a1,a2))

-        val gainCalculations = splitWiseCalculations.countByKey()
-          .toMap //TODO: Hack to go from mutable to immutable map. Clean this up if needed.
+        //TODO: Unnecessary step. Use indices directly instead of creating a map. Not a big hit in performance. Optimize later.
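+        //Layout of countVecToVec: for split i, slots i*(numLabels*numChildren) + {0, 1} hold the
+        //left-child counts for labels 0 and 1, and slots i*(numLabels*numChildren) + numLabels + {0, 1}
+        //hold the right-child counts -- four slots per split when numLabels = numChildren = 2.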
+        var newGainCalculations = Map[(Split,String,Double),Long]()
+        splits.zipWithIndex.foreach{case(split,i) =>
+          newGainCalculations += ((split,"left",0.0) -> countVecToVec(i*(numLabels*numChildren) + 0))
+          newGainCalculations += ((split,"left",1.0) -> countVecToVec(i*(numLabels*numChildren) + 1))
+          newGainCalculations += ((split,"right",0.0) -> countVecToVec(i*(numLabels*numChildren) + numLabels + 0))
+          newGainCalculations += ((split,"right",1.0) -> countVecToVec(i*(numLabels*numChildren) + numLabels + 1))
+        }
+
+        //TODO: Vectorize this operation as well
         val split_gain_list = for (
           split <- availableSplits;
-          gain = impurity.calculateClassificationGain(split, gainCalculations)
+          //gain = impurity.calculateClassificationGain(split, gainCalculations)
+          gain = impurity.calculateClassificationGain(split, newGainCalculations)
         ) yield (split, gain)

         val split_gain = split_gain_list.reduce(comparePair(_, _))
@@ -217,5 +245,16 @@ class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node {
   }
 }

+object NodeHelper extends Serializable {
+
+  //There definitely has to be a library function to do this!
+  def sumTwoArrays(a1 : Array[Long], a2 : Array[Long]) : Array[Long] = {
+    val storage = new Array[Long](a1.length)
+    for (i <- 0 until a1.length){storage(i) = a1(i) + a2(i)}
+    storage
+  }
+
+}
+

From 271d1f4fc2946feff82000c6f152223ba5bfb994 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 10 Nov 2013 20:33:27 -0800
Subject: [PATCH 19/19] using RDD aggregate to count

---
 .../scala/ml/tree/node/decisionNodes.scala | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala
index ac865e7..5b58444 100644
--- a/src/main/scala/ml/tree/node/decisionNodes.scala
+++ b/src/main/scala/ml/tree/node/decisionNodes.scala
@@ -87,9 +87,9 @@ abstract class DecisionNode(
         val numChildren = 2
         val lenSplits = splits.length
         val outputVectorLength = numLabels * numChildren * lenSplits
-        val vecToVec : RDD[Array[Long]] = data.map(
+        val vecToVec : RDD[Array[Int]] = data.map(
           sample => {
-            val storage : Array[Long] = new Array[Long](outputVectorLength)
+            val storage : Array[Int] = new Array[Int](outputVectorLength)
             val label = sample._1
             val features = sample._2
             splits.zipWithIndex.foreach{case (split, i) =>
@@ -107,7 +107,10 @@ abstract class DecisionNode(
           }
         )

-        val countVecToVec : Array[Long] = vecToVec.reduce((a1,a2) => NodeHelper.sumTwoArrays(a1,a2))
+        //val countVecToVec : Array[Long] = vecToVec.reduce((a1,a2) => NodeHelper.sumTwoArrays(a1,a2))
+        val countVecToVec : Array[Long] =
+          vecToVec.aggregate(new Array[Long](outputVectorLength))(NodeHelper.sumLongIntArrays,NodeHelper.sumTwoLongArrays)
+
         //TODO: Unnecessary step. Use indices directly instead of creating a map. Not a big hit in performance. Optimize later.
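         //In the aggregate call above, sumLongIntArrays is the seqOp that folds each record's
         //Array[Int] counts into a per-partition Array[Long] accumulator, and sumTwoLongArrays is
         //the combOp that merges the per-partition totals, so the per-record vectors can stay Int-sized.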
         var newGainCalculations = Map[(Split,String,Double),Long]()
@@ -118,7 +121,6 @@
           newGainCalculations += ((split,"right",1.0) -> countVecToVec(i*(numLabels*numChildren) + numLabels + 1))
         }

-        //TODO: Vectorize this operation as well
         val split_gain_list = for (
           split <- availableSplits;
           //gain = impurity.calculateClassificationGain(split, gainCalculations)
@@ -174,7 +176,8 @@ abstract class DecisionNode(
 /*
  * Top node for building a classification tree
  */
-class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends ClassificationNode(input.cache, 1, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) {
+class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int)
+  extends ClassificationNode(input.cache, 1, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) {
   override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
 }

@@ -201,7 +204,8 @@ class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPr
 /*
  * Top node for building a regression tree
  */
-class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int) extends RegressionNode(input.cache, 1, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) {
+class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int)
+  extends RegressionNode(input.cache, 1, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) {
   override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
 }

@@ -248,7 +252,14 @@ class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node {
 object NodeHelper extends Serializable {

   //There definitely has to be a library function to do this!
-  def sumTwoArrays(a1 : Array[Long], a2 : Array[Long]) : Array[Long] = {
+  def sumTwoLongArrays(a1 : Array[Long], a2 : Array[Long]) : Array[Long] = {
+    val storage = new Array[Long](a1.length)
+    for (i <- 0 until a1.length){storage(i) = a1(i) + a2(i)}
+    storage
+  }
+
+  //There definitely has to be a library function to do this!
+  def sumLongIntArrays(a1 : Array[Long], a2 : Array[Int]) : Array[Long] = {
     val storage = new Array[Long](a1.length)
     for (i <- 0 until a1.length){storage(i) = a1(i) + a2(i)}
     storage
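     //(A library one-liner does exist: a1.zip(a2).map { case (x, y) => x + y } computes the same
     //element-wise sum, at the cost of allocating an intermediate tuple per element.)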