Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tree #3

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ project/plugins/project/
#Eclipse specific
.classpath
.project

#IDEA specific
.idea
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good idea, thanks.

.idea_modules
156 changes: 156 additions & 0 deletions src/main/scala/ml/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.tree
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should refactor this as mli.ml.tree

import javax.naming.OperationNotSupportedException
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.ClassificationModel
import org.apache.spark.SparkContext
import org.apache.spark.util.StatCounter
import org.apache.spark.Logging
import ml.tree.impurity.{Variance, Entropy, Gini, Impurity}
import ml.tree.strategy.Strategy
import ml.tree.split.{SplitPredicate, Split}
import org.apache.spark.broadcast.Broadcast
import scala.Some
import ml.tree.strategy.Strategy
import ml.tree.split.Split
import ml.tree.node._
import ml.tree.Metrics._
import scala.Some
import ml.tree.strategy.Strategy
import ml.tree.split.Split


/*
* Class for building the Decision Tree model. Should be used for both classification and regression tree.
*/
class DecisionTree (
val input: RDD[(Double, Array[Double])], //input RDD
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great if this could take an MLNumericTable as input with a "targetCol" that indicates which column to use as response. This may impact some design choices later (e.g. enforces a row-oriented perspective).

val maxDepth: Int, // depth of the tree
val numSplitPredicates: Int, // number of bins per features
val fraction: Double, // fraction of the data to be used for performing quantile calculation
val strategy: Strategy, // classification or regression
val impurity: Impurity, // impurity calculation strategy (variance, gini, entropy, etc.)
val sparkContext : SparkContext) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto for MLContext here.


//Calculating length of the features
val featureLength = input.first._2.length
println("feature length = " + featureLength)

//Sampling a fraction of the input RDD
val sampledData = input.sample(false, fraction, 42).cache()
println("sampled data size for quantile calculation = " + sampledData.count)

//Sorting the sampled data along each feature and storing it for quantile calculation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're doing exact quantiling here, which is good. The downside here is it requires a sort on each data column, which is potentially expensive.

I wonder how bad it would be to approximate with low-high and split values at intervals of (low-high)/#bins - this way we'd effectively have fixed width histograms in one pass over the data (just collect max and min for each feature). This is pretty outlier sensitive, but maybe worth thinking about.

Finally, I'm wondering how effective this strategy will be on low-variance or highly sparse columns where 90% of the data has the same value. In that case, your conversion of split predicates to sets effectively drops the #of split candidates for a particular column down to close to 0 (or even to 0), even though there might be some signal there. I've never used Decision Trees on data with that kind of skew in practice, but I imagine it's common in e.g. online marketing.

println("started sorting sampled data")
val sortedSampledFeatures = {
val sortedFeatureArray = new Array[Array[Double]](featureLength)
0 until featureLength foreach {
i => sortedFeatureArray(i) = sampledData.map(x => x._2(i) -> None).sortByKey(true).map(_._1).collect()
}
sortedFeatureArray
}
println("finished sorting sampled data")

val numSamples = sampledData.count
println("num samples = " + numSamples)

// Calculating the index to jump to find the quantile points
val stride = scala.math.max(numSamples / numSplitPredicates, 1)
println("stride = " + stride)

//Calculating all possible splits for the features
println("calculating all possible splits for features")
val allSplitsList = for {
featureIndex <- 0 until featureLength;
index <- stride until numSamples - 1 by stride
} yield createSplit(featureIndex, index)
println("finished calculating all possible splits for features")

//Remove duplicate splits. Especially help for one-hot encoded categorical variables.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the current setup handle categorical variables? It seems like we're treating them just like ordered numeric variables now?

val allSplits = sparkContext.broadcast(allSplitsList.toSet)

//for (split <- allSplits) yield println(split)

/*
* Find the exact value using feature index and index into the sorted features
*/
def valueAtRDDIndex(featuresIndex: Long, index: Long): Double = {
sortedSampledFeatures(featuresIndex.toInt)(index.toInt)
}

/*
* Create splits using feature index and index into the sorted features
*/
def createSplit(featureIndex: Int, index: Long): Split = {
new Split(featureIndex, valueAtRDDIndex(featureIndex, index))
}

def buildTree(): Node = {

println("building decision tree")

strategy match {
case Strategy("Classification") => new TopClassificationNode(input, allSplits, impurity, strategy, maxDepth)
case Strategy("Regression") => {
val count = input.count
//TODO: calculate mean and variance together
val variance = input.map(x => x._1).variance
val mean = input.map(x => x._1).mean
val nodeStats = new NodeStats(count = Some(count), variance = Some(variance), mean = Some(mean))
new TopRegressionNode(input, nodeStats,allSplits, impurity, strategy, maxDepth)
}
}
}

}


object DecisionTree {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with the rest of the MLI codebase, we might want wrapper objects like RegressionTreeModel and RegressionTreeAlgorithm as well as DecisionTreeModel/DecisionTreeAlgorithm and put those in mli.ml.regression and classification respectively. They'd be thin wrappers around DecisionTree.train() (and the common 'tree' module is fine to live in its own area!)

def train(
input: RDD[(Double, Array[Double])],
numSplitPredicates: Int,
strategy: Strategy,
impurity: Impurity,
maxDepth : Int,
fraction : Double,
sparkContext : SparkContext): Option[NodeModel] = {
val tree = new DecisionTree(
input = input,
numSplitPredicates = numSplitPredicates,
strategy = strategy,
impurity = impurity,
maxDepth = maxDepth,
fraction = fraction,
sparkContext = sparkContext)
.buildTree
.extractModel

println("calculating performance on training data")
val trainingError = {
strategy match {
case Strategy("Classification") => accuracyScore(tree, input)
case Strategy("Regression") => meanSquaredError(tree, input)
}
}
println("accuracy = " + trainingError)

tree
}
}
31 changes: 31 additions & 0 deletions src/main/scala/ml/tree/Metrics.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package ml.tree

import org.apache.spark.SparkContext._
import ml.tree.node.NodeModel
import org.apache.spark.rdd.RDD

/*
Helper methods for measuring performance of ML algorithms
*/
object Metrics {

//TODO: Make these generic MLTable metrics.
def accuracyScore(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
if (tree.isEmpty) return 1 //TODO: Throw exception
val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count()
val count = data.count()
println("correct prediction count = " + correctCount)
println("data count = " + count)
correctCount.toDouble / count
}

//TODO: Make these generic MLTable metrics
def meanSquaredError(tree : Option[NodeModel], data : RDD[(Double, Array[Double])]) : Double = {
if (tree.isEmpty) return 1 //TODO: Throw exception
val meanSumOfSquares = data.map(y => (tree.get.predict(y._2) - y._1)*(tree.get.predict(y._2) - y._1)).mean()
println("meanSumOfSquares = " + meanSumOfSquares)
meanSumOfSquares
}


}
38 changes: 38 additions & 0 deletions src/main/scala/ml/tree/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#Decision Tree
Decision tree classifiers are both popular supervised learning algorithms and also building blocks for other ensemble learning algorithms such as random forests, boosting, etc. This document discusses its implementation in the Spark project.

#Usage
```
ml.tree.TreeRunner
<master>[slices]
--strategy <Classification,Regression>
--trainDataDir path
--testDataDir path
[--maxDepth num]
[--impurity <Gini,Entropy,Variance>]
[--samplingFractionForSplitCalculation num]
```

#Example
```
sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification
--trainDataDir ../train_data --testDataDir ../test_data
--maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1
```

This command will create a decision tree model using the training data in the *trainDataDir* and calculate test error using the data in the *testDataDir*. The mis-classification error is calculated for a Classification *strategy* and mean squared error is calculated for the Regression *strategy*.

#Performance testing
To be done

#Improvements
* Print to dot files
* Unit tests
* Change fractions to quantiles
* Add logging
* Move metrics to a different package

#Extensions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

* Extremely randomized trees
* Random forest
* Boosting
93 changes: 93 additions & 0 deletions src/main/scala/ml/tree/TreeRunner.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package ml.tree
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nice program to have - we should really think about how we want to standardize on command line utils. I think a single driver program that can instantiate models/write models/results out to disk would be really valuable and save a bunch of duplicated effort. For now this is great though!


import org.apache.spark.SparkContext._
import org.apache.spark.{Logging, SparkContext}
import ml.tree.impurity.{Variance, Entropy, Gini}
import ml.tree.strategy.Strategy

import ml.tree.node.NodeModel
import org.apache.spark.rdd.RDD

import ml.tree.Metrics.{accuracyScore,meanSquaredError}

object TreeRunner extends Logging {
val usage = """
Usage: TreeRunner <master>[slices] --strategy <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--samplingFractionForSplitCalculation num]
"""

def main(args: Array[String]) {

if (args.length < 2) {
System.err.println(usage)
System.exit(1)
}

/**START Experimental*/
System.setProperty("spark.cores.max", "8")
/**END Experimental*/
val sc = new SparkContext(args(0), "Decision Tree Runner",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))


val arglist = args.toList.drop(1)
type OptionMap = Map[Symbol, Any]

def nextOption(map : OptionMap, list: List[String]) : OptionMap = {
def isSwitch(s : String) = (s(0) == '-')
list match {
case Nil => map
case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail)
case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail)
case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail)
case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail)
case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
case option :: tail => println("Unknown option "+option)
exit(1)
}
}
val options = nextOption(Map(),arglist)
println(options)
//TODO: Add check for acceptable string inputs

val trainData = TreeUtils.loadLabeledData(sc, options.get('trainDataDir).get.toString)
val strategyStr = options.get('strategy).get.toString
val impurityStr = options.getOrElse('impurity,"Gini").toString
val impurity = {
impurityStr match {
case "Gini" => Gini
case "Entropy" => Entropy
case "Variance" => Variance
}
}
val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt
val fraction = options.getOrElse('samplingFractionForSplitCalculation,"1.0").toString.toDouble

val tree = DecisionTree.train(
input = trainData,
numSplitPredicates = 1000,
strategy = new Strategy(strategyStr),
impurity = impurity,
maxDepth = maxDepth,
fraction = fraction,
sparkContext = sc)
println(tree)
//println("prediction = " + tree.get.predict(Array(1.0, 2.0)))

println("loading test data")
val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString)

println("calculating performance of test data")
val testError = {
strategyStr match {
case "Classification" => accuracyScore(tree, testData)
case "Regression" => meanSquaredError(tree, testData)
}
}
println("accuracy = " + testError)

}


}
37 changes: 37 additions & 0 deletions src/main/scala/ml/tree/TreeUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package ml.tree

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD


//TODO: Deprecate this when we find something equivalent in ml utils
/**
* Helper methods to load and save data
* Data format:
* <l>, <f1> <f2> ...
* where <f1>, <f2> are feature values in Double and <l> is the corresponding label as Double.
*/
object TreeUtils {

/**
* @param sc SparkContext
* @param dir Directory to the input data files.
* @return An RDD of tuples. For each tuple, the first element is the label, and the second
* element represents the feature values (an array of Double).
*/
def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably just be using the standard MLContext.loadCsvFile here (though it should probably generate an MLNumericTable).

sc.textFile(dir).map { line =>
val parts = line.trim().split(",")
val label = parts(0).toDouble
val features = parts.slice(1,parts.length).map(_.toDouble)
//val features = parts.slice(1, 30).map(_.toDouble)
(label, features)
}
}

def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) {
val dataStr = data.map(x => x._1 + "," + x._2.mkString(" "))
dataStr.saveAsTextFile(dir)
}

}
Loading