Commit 92cedce

basic building blocks for intermediate RDD calculation. untested.
Signed-off-by: Manish Amde <[email protected]>
1 parent cd53eae commit 92cedce

6 files changed: +262 -15 lines changed


mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 164 additions & 6 deletions
@@ -17,9 +17,12 @@

 package org.apache.spark.mllib.tree

+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.Split


 class DecisionTree(val strategy : Strategy) {
@@ -30,25 +33,180 @@ class DecisionTree(val strategy : Strategy) {
     input.cache()

     //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass
+    //TODO: Think about broadcasting this
     val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

     //TODO: Level-wise training of tree and obtain Decision Tree model

+    val maxDepth = strategy.maxDepth
+
+    val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
+    val filters = new Array[List[Filter]](maxNumNodes)
+
+    for (level <- 0 until maxDepth) {
+      //Find best split for all nodes at a level
+      val numNodes = scala.math.pow(2, level).toInt
+      val bestSplits = DecisionTree.findBestSplits(input, strategy, level, filters, splits, bins)
+      //TODO: update filters and decision tree model
+    }

     return new DecisionTreeModel()
   }

 }

-object DecisionTree {
+object DecisionTree extends Logging {
+
+  def findBestSplits(
+      input : RDD[LabeledPoint],
+      strategy: Strategy,
+      level: Int,
+      filters : Array[List[Filter]],
+      splits : Array[Array[Split]],
+      bins : Array[Array[Bin]]) : Array[Split] = {
+
+    //Reconstruct the filter chain from the root down to a node's parent
+    def findParentFilters(nodeIndex: Int): List[Filter] = {
+      if (level == 0) {
+        List[Filter]()
+      } else {
+        val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
+        val parentFilterIndex = nodeFilterIndex / 2
+        filters(parentFilterIndex)
+      }
+    }
+
+    //Check whether a sample satisfies every filter on the path to this node
+    def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+      for (filter <- parentFilters) {
+        val features = labeledPoint.features
+        val featureIndex = filter.split.feature
+        val threshold = filter.split.threshold
+        val comparison = filter.comparison
+        comparison match {
+          case(-1) => if (features(featureIndex) > threshold) return false
+          case(0) => if (features(featureIndex) != threshold) return false
+          case(1) => if (features(featureIndex) <= threshold) return false
+        }
+      }
+      true
+    }
+
+    //Find the bin whose (lowSplit, highSplit] interval contains the feature value
+    def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
+      //TODO: Do binary search
+      for (binIndex <- 0 until strategy.numSplits) {
+        val bin = bins(featureIndex)(binIndex)
+        //TODO: Remove this requirement post basic functional testing
+        require(bin.lowSplit.feature == featureIndex)
+        require(bin.highSplit.feature == featureIndex)
+        val lowThreshold = bin.lowSplit.threshold
+        val highThreshold = bin.highSplit.threshold
+        val features = labeledPoint.features
+        if ((lowThreshold < features(featureIndex)) && (highThreshold >= features(featureIndex))) {
+          return binIndex
+        }
+      }
+      throw new UnknownError("no bin was found.")
+    }
+
+    //For one sample, compute a (bin index, label) pair per feature per node at this level
+    def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+      val numNodes = scala.math.pow(2, level).toInt
+      //Find the number of features from the current sample
+      //(calling input.take here would nest an RDD action inside the map closure)
+      val numFeatures = labeledPoint.features.length
+
+      //TODO: Bit pack more by removing redundant label storage
+      // calculating bin index and label per feature per node
+      val arr = new Array[Double](2 * numFeatures * numNodes)
+      for (nodeIndex <- 0 until numNodes) {
+        val parentFilters = findParentFilters(nodeIndex)
+        //Find out whether the sample qualifies for the particular node
+        val sampleValid = isSampleValid(parentFilters, labeledPoint)
+        val shift = 2 * numFeatures * nodeIndex
+        if (!sampleValid) {
+          //Mark invalid samples with bin index -1
+          for (featureIndex <- 0 until numFeatures) {
+            arr(shift + (featureIndex * 2)) = -1
+            arr(shift + (featureIndex * 2) + 1) = labeledPoint.label
+          }
+        } else {
+          for (featureIndex <- 0 until numFeatures) {
+            arr(shift + (featureIndex * 2)) = findBin(featureIndex, labeledPoint)
+            arr(shift + (featureIndex * 2) + 1) = labeledPoint.label
+          }
+        }
+      }
+      arr
+    }
+
+    val binMappedRDD = input.map(labeledPoint => findBinsForLevel(labeledPoint))
+    //calculate bin aggregates
+    //find best split

+    Array[Split]()
+  }
+
   def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
+
     val numSplits = strategy.numSplits
+    logDebug("numSplits = " + numSplits)
+
+    //Calculate the number of samples for approximate quantile calculation
     //TODO: Justify this calculation
-    val requiredSamples : Long = numSplits*numSplits
-    val count : Long = input.count()
-    val numSamples : Long = if (requiredSamples < count) requiredSamples else count
+    val requiredSamples = numSplits * numSplits
+    val count = input.count()
+    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
+    logDebug("fraction of data used for calculating quantiles = " + fraction)
+
+    //sampled input for quantile calculation
+    val sampledInput = input.sample(false, fraction, 42).collect()
+    val numSamples = sampledInput.length
+
+    require(numSamples > numSplits, "length of input samples should be greater than numSplits")
+
+    //Find the number of features by looking at the first sample
     val numFeatures = input.take(1)(0).features.length
-    (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits))
+
+    strategy.quantileCalculationStrategy match {
+      case "sort" => {
+        val splits = Array.ofDim[Split](numFeatures, numSplits - 1)
+        val bins = Array.ofDim[Bin](numFeatures, numSplits)
+
+        //Find all splits
+        for (featureIndex <- 0 until numFeatures) {
+          val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+          val stride : Double = numSamples.toDouble / numSplits
+          for (index <- 0 until numSplits - 1) {
+            val sampleIndex = ((index + 1) * stride).toInt
+            val split = new Split(featureIndex, featureSamples(sampleIndex), "continuous")
+            splits(featureIndex)(index) = split
+          }
+        }
+
+        //Find all bins
+        for (featureIndex <- 0 until numFeatures) {
+          bins(featureIndex)(0)
+            = new Bin(new DummyLowSplit("continuous"), splits(featureIndex)(0), "continuous")
+          for (index <- 1 until numSplits - 1) {
+            val bin = new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), "continuous")
+            bins(featureIndex)(index) = bin
+          }
+          bins(featureIndex)(numSplits - 1)
+            = new Bin(splits(featureIndex)(numSplits - 2), new DummyHighSplit("continuous"), "continuous")
+        }
+
+        (splits, bins)
+      }
+      case "minMax" => {
+        (Array.ofDim[Split](numFeatures, numSplits), Array.ofDim[Bin](numFeatures, numSplits + 2))
+      }
+      case "approximateHistogram" => {
+        throw new UnsupportedOperationException("approximate histogram not supported yet.")
+      }
+    }
   }

 }
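
The per-sample array built by findBinsForLevel packs, for each of the 2^level nodes at the current level, a (bin index, label) pair per feature, with bin index -1 marking samples filtered out of a node. A minimal standalone sketch of that layout (hypothetical helper names, not part of this commit):

object BinLayoutSketch {
  // bin index for (node, feature) lives at 2 * numFeatures * nodeIndex + 2 * featureIndex;
  // the sample's label is stored in the slot right after it
  def binIndexFor(arr: Array[Double], numFeatures: Int, nodeIndex: Int, featureIndex: Int): Double =
    arr(2 * numFeatures * nodeIndex + 2 * featureIndex)

  def labelFor(arr: Array[Double], numFeatures: Int, nodeIndex: Int, featureIndex: Int): Double =
    arr(2 * numFeatures * nodeIndex + 2 * featureIndex + 1)

  def main(args: Array[String]): Unit = {
    // level = 1 gives 2 nodes; with 2 features that is 2 * 2 * 2 = 8 slots per sample
    val arr = Array(3.0, 1.0, 7.0, 1.0, -1.0, 1.0, -1.0, 1.0)
    println(binIndexFor(arr, numFeatures = 2, nodeIndex = 0, featureIndex = 1)) // 7.0
    println(labelFor(arr, numFeatures = 2, nodeIndex = 1, featureIndex = 0))    // 1.0; its bin index is -1 (invalid)
  }
}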

mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala

Lines changed: 1 addition & 1 deletion
@@ -23,6 +23,6 @@ class Strategy (
     val impurity : Impurity,
     val maxDepth : Int,
     val numSplits : Int,
-    val quantileCalculationStrategy : String = "sampleAndSort") {
+    val quantileCalculationStrategy : String = "sort") {

 }
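
With the rename, the default now matches one of the three values find_splits_bins pattern-matches on ("sort", "minMax", "approximateHistogram"). A usage sketch, with the constructor shape inferred from the new test suite below (the first argument appears to be the algorithm kind):

import org.apache.spark.mllib.tree.impurity.Gini

// kind, impurity, maxDepth, numSplits, quantileCalculationStrategy
val strategy = new Strategy("regression", Gini, 3, 100, "sort")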

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 1 addition & 1 deletion
@@ -16,6 +16,6 @@
  */
 package org.apache.spark.mllib.tree.model

-case class Bin(kind : String, lowSplit : Split, highSplit : Split) {
+case class Bin(lowSplit : Split, highSplit : Split, kind : String) {

 }
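
The reordering makes a bin read as a (lowSplit, highSplit] interval plus a kind, in the same order find_splits_bins constructs them. A small sketch, assuming the half-open interval semantics that findBin checks:

val low  = new Split(0, 0.5, "continuous")
val high = new Split(0, 1.5, "continuous")
val bin  = new Bin(low, high, "continuous") // holds feature-0 values v with 0.5 < v <= 1.5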
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.tree.model
+
+case class Filter(split : Split, comparison : Int) {
+  // comparison of -1, 0, 1 signifies <, =, >
+  override def toString = " split = " + split + ", comparison = " + comparison
+}
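
The comparison code mirrors what isSampleValid in DecisionTree.scala enforces: a sample passes a -1 filter only if the feature is at or below the split threshold, a 1 filter only if it is above, and a 0 filter only on an exact match. A brief sketch (illustrative values):

val split = new Split(2, 10.0, "continuous")
val wentLeft  = new Filter(split, -1) // sample passes only if features(2) <= 10.0
val wentRight = new Filter(split, 1)  // sample passes only if features(2) > 10.0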

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 4 additions & 7 deletions
@@ -16,14 +16,11 @@
  */
 package org.apache.spark.mllib.tree.model

-case class Split(
-    val feature: Int,
-    val threshold : Double,
-    val kind : String) {
-
+case class Split(feature: Int, threshold : Double, kind : String) {
+  override def toString = "Feature = " + feature + ", threshold = " + threshold + ", kind = " + kind
 }

-class dummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind)
+class DummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind)

-class dummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind)
+class DummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind)

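The renamed DummyLowSplit and DummyHighSplit act as sentinels with Double.MinValue and Double.MaxValue thresholds, closing the first and last bin of each feature in find_splits_bins. For instance (illustrative thresholds, assuming the model package is imported):

val dummyLow  = new DummyLowSplit("continuous")  // threshold = Double.MinValue
val dummyHigh = new DummyHighSplit("continuous") // threshold = Double.MaxValue
val firstBin = new Bin(dummyLow, new Split(0, 0.5, "continuous"), "continuous")
val lastBin  = new Bin(new Split(0, 99.5, "continuous"), dummyHigh, "continuous")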
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.tree
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+
+import org.jblas._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impurity.Gini
+
+class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
+
+  @transient private var sc: SparkContext = _
+
+  override def beforeAll() {
+    sc = new SparkContext("local", "test")
+  }
+
+  override def afterAll() {
+    sc.stop()
+    System.clearProperty("spark.driver.port")
+  }
+
+  test("split and bin calculation") {
+    val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
+    assert(arr.length == 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy("regression", Gini, 3, 100, "sort")
+    val (splits, bins) = DecisionTree.find_splits_bins(rdd, strategy)
+    assert(splits.length == 2)
+    assert(bins.length == 2)
+    assert(splits(0).length == 99)
+    assert(bins(0).length == 100)
+    println(splits(1)(98))
+  }
+}
+
+object DecisionTreeSuite {
+
+  def generateReverseOrderedLabeledPoints() : Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](1000)
+    for (i <- 0 until 1000) {
+      val lp = new LabeledPoint(1.0, Array(i.toDouble, 1000.0 - i))
+      arr(i) = lp
+    }
+    arr
+  }
+
+}
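
For this fixture (1000 points, numSplits = 100), the sort strategy samples at stride 1000/100 = 10, so feature-0 split thresholds should land at sorted positions 10, 20, ..., 990, which is what the splits(0).length == 99 assertion indirectly exercises. A quick local check of that arithmetic, independent of Spark:

val featureSamples = (0 until 1000).map(_.toDouble).sorted
val stride = featureSamples.length.toDouble / 100
val thresholds = (0 until 99).map(i => featureSamples(((i + 1) * stride).toInt))
assert(thresholds.head == 10.0 && thresholds.last == 990.0)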
