@@ -306,6 +306,20 @@ object DecisionTree extends Serializable with Logging {
306306 arr
307307 }
308308
309+ /**
310+ Performs a sequential aggregation over a partition for classification.
311+
312+ for p bins, k features, l nodes (level = log2(l)) storage is of the form:
313+ b111_left_count,b111_right_count, .... , ..
314+ .. bpk1_left_count, bpk1_right_count, .... , ..
315+ .. bpkl_left_count, bpkl_right_count
316+
317+ @param agg Array[Double] storing aggregate calculation of size
318+ 2*numSplits*numFeatures*numNodes for classification
319+ @param arr Array[Double] of size 1+(numFeatures*numNodes)
320+ @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes
321+ for classification
322+ */
309323 def classificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) {
310324 for (node <- 0 until numNodes) {
311325 val validSignalIndex = 1 + numFeatures * node
@@ -326,6 +340,20 @@ object DecisionTree extends Serializable with Logging {
326340 }
327341 }
328342
343+ /**
344+ Performs a sequential aggregation over a partition for regression.
345+
346+ for p bins, k features, l nodes (level = log2(l)) storage is of the form:
347+ b111_count,b111_sum, b111_sum_squares .... , ..
348+ .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , ..
349+ .. bpkl_count, bpkl_sum, bpkl_sum_squares
350+
351+ @param agg Array[Double] storing aggregate calculation of size
352+ 3*numSplits*numFeatures*numNodes for regression
353+ @param arr Array[Double] of size 1+(numFeatures*numNodes)
354+ @return Array[Double] storing aggregate calculation of size 3*numSplits*numFeatures*numNodes
355+ for regression
356+ */
329357 def regressionBinSeqOp (arr : Array [Double ], agg : Array [Double ]) {
330358 for (node <- 0 until numNodes) {
331359 val validSignalIndex = 1 + numFeatures * node
@@ -354,11 +382,11 @@ object DecisionTree extends Serializable with Logging {
354382 .. bpk1_left_count, bpk1_right_count, .... , ..
355383 .. bpkl_left_count, bpkl_right_count
356384
357- @param agg Array[Double] storing aggregate calculation of size
358- 2 *numSplits*numFeatures*numNodes for classification
385+ @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification
386+ and 3*numSplits*numFeatures*numNodes for regression
359387 @param arr Array[Double] of size 1+(numFeatures*numNodes)
360- @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes
361- for classification
388+ @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification
389+ and 3*numSplits*numFeatures*numNodes for regression
362390 */
363391 def binSeqOp (agg : Array [Double ], arr : Array [Double ]) : Array [Double ] = {
364392 strategy.algo match {
@@ -411,7 +439,15 @@ object DecisionTree extends Serializable with Logging {
411439 logDebug(" binAggregates.length = " + binAggregates.length)
412440 // binAggregates.foreach(x => logDebug(x))
413441
414-
442+ /**
443+ * Calculates the information gain for a given (feature, split) pair
444+ * @param leftNodeAgg left node aggregates
445+ * @param featureIndex feature index
446+ * @param splitIndex split index
447+ * @param rightNodeAgg right node aggregates
448+ * @param topImpurity impurity of the parent node
449+ * @return information gain and statistics for the given split
450+ */
415451 def calculateGainForSplit (leftNodeAgg : Array [Array [Double ]],
416452 featureIndex : Int ,
417453 splitIndex : Int ,
0 commit comments