@@ -18,20 +18,109 @@ package org.apache.spark.mllib.rdd
1818
1919import breeze .linalg .{axpy , Vector => BV }
2020
21- import org .apache .spark .mllib .linalg .Vector
21+ import org .apache .spark .mllib .linalg .{ Vectors , Vector }
2222import org .apache .spark .rdd .RDD
2323
/**
 * Trait of the summary statistics, including mean, variance, count, max, min, and non-zero
 * elements count.
 */
/** Column-wise summary statistics of an RDD of vectors. */
trait VectorRDDStatisticalSummary {

  /** Column-wise mean over all vectors (implicit zeros included). */
  def mean(): Vector

  /** Column-wise population variance over all vectors. */
  def variance(): Vector

  /** Total number of vectors aggregated. */
  def totalCount(): Long

  /** Column-wise count of non-zero entries. */
  def numNonZeros(): Vector

  /** Column-wise maximum. */
  def max(): Vector

  /** Column-wise minimum. */
  def min(): Vector
}
36+
/**
 * Accumulator for column-wise statistics over a stream of Breeze vectors.
 *
 * `currMean` and `currM2n` track, per coordinate, the running mean and the sum of squared
 * deltas over the NON-ZERO entries only; the implicit zeros are folded in at query time via
 * `nnz` and `totalCnt` (see `mean()` / `variance()`), which keeps updates cheap for sparse
 * input.
 *
 * @param currMean per-coordinate running mean of the non-zero entries
 * @param currM2n  per-coordinate running sum of squared deltas from the mean (the "M2" term
 *                 of Welford's algorithm), non-zero entries only
 * @param totalCnt total number of vectors aggregated so far
 * @param nnz      per-coordinate count of non-zero entries
 * @param currMax  per-coordinate maximum seen so far
 * @param currMin  per-coordinate minimum seen so far
 */
private class Aggregator(
    val currMean: BV[Double],
    val currM2n: BV[Double],
    var totalCnt: Double,
    val nnz: BV[Double],
    val currMax: BV[Double],
    val currMin: BV[Double]) extends VectorRDDStatisticalSummary {

  // A coordinate with no non-zero entries is all zeros, so its max and min are both 0.0;
  // overwrite whatever sentinel values those buffers were seeded with.
  // NOTE(review): this runs once at construction, so it only corrects the buffers passed in —
  // confirm the final Aggregator is built from fully merged buffers before querying.
  nnz.activeIterator.foreach {
    case (id, 0.0) =>
      currMax(id) = 0.0
      currMin(id) = 0.0
    case _ =>
  }

  /** Column-wise mean over all vectors: the non-zero-entry mean is rescaled by
   *  nnz / totalCnt so that the implicit zero entries are counted. */
  override def mean(): Vector = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)

  /** Column-wise population variance. The subtracted term folds the implicit zero entries
   *  into the non-zero-only `currM2n`:
   *  M2_total = currM2n + currMean^2 * nnz * (totalCnt - nnz) / totalCnt. */
  override def variance(): Vector = {
    val deltaMean = currMean
    val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
    realM2n :/= totalCnt
    Vectors.fromBreeze(realM2n)
  }

  /** Number of vectors aggregated. */
  override def totalCount(): Long = totalCnt.toLong

  /** Per-coordinate count of non-zero entries. */
  override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)

  /** Column-wise maximum. */
  override def max(): Vector = Vectors.fromBreeze(currMax)

  /** Column-wise minimum. */
  override def min(): Vector = Vectors.fromBreeze(currMin)

  /**
   * Aggregate function used for aggregating elements in a worker together.
   *
   * BUG FIX: the running mean previously advanced as
   * (currMean * totalCnt + value) / (totalCnt + 1), i.e. weighted by the global vector count,
   * while `mean()`, `variance()` and `merge()` all treat `currMean` as the mean over the
   * nnz(id) non-zero entries of the coordinate. The update now uses the per-coordinate count.
   * `totalCnt` counts vectors, so it is incremented once per call, not per non-zero element.
   */
  def add(currData: BV[Double]): this.type = {
    currData.activeIterator.foreach {
      case (id, 0.0) => // implicit zeros are folded in at query time
      case (id, value) =>
        if (currMax(id) < value) currMax(id) = value
        if (currMin(id) > value) currMin(id) = value

        // Welford update over the non-zero entries of this coordinate.
        val tmpPrevMean = currMean(id)
        currMean(id) = tmpPrevMean + (value - tmpPrevMean) / (nnz(id) + 1.0)
        currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)

        nnz(id) += 1.0
    }
    totalCnt += 1.0
    this
  }

  /**
   * Combine function used for combining intermediate results together from every worker.
   *
   * Means are combined weighted by non-zero counts; the M2n terms are combined with the
   * standard pairwise (Chan et al.) cross term deltaMean^2 * n_a * n_b / (n_a + n_b).
   *
   * BUG FIX: the loops previously iterated over `other.currMean` / `other.currM2n` active
   * entries, which skips coordinates whose mean or M2n happens to be 0.0 even though
   * other.nnz(id) > 0 (e.g. a coordinate that saw exactly one non-zero value has M2n == 0
   * but still needs the cross term). Iterating over `other.nnz` visits exactly the
   * coordinates that carry data.
   */
  def merge(other: this.type): this.type = {
    totalCnt += other.totalCnt
    // Snapshot of the mean differences, taken before this.currMean is updated below.
    val deltaMean = currMean - other.currMean

    other.nnz.activeIterator.foreach {
      case (id, 0.0) =>
      case (id, otherNnz) =>
        // M2n first: it uses the pre-merge nnz(id) and the deltaMean snapshot.
        currM2n(id) += other.currM2n(id) +
          deltaMean(id) * deltaMean(id) * nnz(id) * otherNnz / (nnz(id) + otherNnz)
        currMean(id) = (currMean(id) * nnz(id) + other.currMean(id) * otherNnz) /
          (nnz(id) + otherNnz)
    }

    other.currMax.activeIterator.foreach {
      case (id, value) =>
        if (currMax(id) < value) currMax(id) = value
    }

    other.currMin.activeIterator.foreach {
      case (id, value) =>
        if (currMin(id) > value) currMin(id) = value
    }

    // Fold other's non-zero counts in last, so the combine formulas above saw the old nnz.
    axpy(1.0, other.nnz, nnz)
    this
  }
}
116+
/**
 * Intermediate per-partition aggregation buffers for vector summary statistics.
 *
 * @param mean     per-coordinate running mean
 * @param statCnt  per-coordinate running second-moment buffer
 *                 (presumably the M2n sum of squared deltas — TODO confirm against users)
 * @param totalCnt total count of aggregated elements
 * @param nnz      per-coordinate count of non-zero entries
 * @param currMax  per-coordinate running maximum
 * @param currMin  per-coordinate running minimum
 */
case class VectorRDDStatisticalAggregator(
    mean: BV[Double],
    statCnt: BV[Double],
    totalCnt: Double,
    nnz: BV[Double],
    currMax: BV[Double],
    currMin: BV[Double])
35124
36125/**
37126 * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector ]] through an