@@ -40,14 +40,12 @@ private class Aggregator(
4040 var totalCnt : Double ,
4141 val nnz : BV [Double ],
4242 val currMax : BV [Double ],
43- val currMin : BV [Double ]) extends VectorRDDStatisticalSummary {
44- nnz.activeIterator.foreach {
45- case (id, 0.0 ) =>
46- currMax(id) = 0.0
47- currMin(id) = 0.0
48- case _ =>
43+ val currMin : BV [Double ]) extends VectorRDDStatisticalSummary with Serializable {
44+
45+ override def mean (): Vector = {
46+ Vectors .fromBreeze(currMean :* nnz :/ totalCnt)
4947 }
50- override def mean () : Vector = Vectors .fromBreeze(currMean :* nnz :/ totalCnt)
48+
5149 override def variance (): Vector = {
5250 val deltaMean = currMean
5351 val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
@@ -58,8 +56,23 @@ private class Aggregator(
5856 override def totalCount (): Long = totalCnt.toLong
5957
6058 override def numNonZeros (): Vector = Vectors .fromBreeze(nnz)
61- override def max (): Vector = Vectors .fromBreeze(currMax)
62- override def min (): Vector = Vectors .fromBreeze(currMin)
59+
60+ override def max (): Vector = {
61+ nnz.activeIterator.foreach {
62+ case (id, 0.0 ) => currMax(id) = 0.0
63+ case _ =>
64+ }
65+ Vectors .fromBreeze(currMax)
66+ }
67+
68+ override def min (): Vector = {
69+ nnz.activeIterator.foreach {
70+ case (id, 0.0 ) => currMin(id) = 0.0
71+ case _ =>
72+ }
73+ Vectors .fromBreeze(currMin)
74+ }
75+
6376 /**
6477 * Aggregate function used for aggregating elements in a worker together.
6578 */
@@ -75,15 +88,19 @@ private class Aggregator(
7588 currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
7689
7790 nnz(id) += 1.0
78- totalCnt += 1.0
7991 }
92+
93+ totalCnt += 1.0
8094 this
8195 }
96+
8297 /**
8398 * Combine function used for combining intermediate results together from every worker.
8499 */
85- def merge (other : this .type ): this .type = {
100+ def merge (other : Aggregator ): this .type = {
101+
86102 totalCnt += other.totalCnt
103+
87104 val deltaMean = currMean - other.currMean
88105
89106 other.currMean.activeIterator.foreach {
@@ -114,132 +131,30 @@ private class Aggregator(
114131 }
115132}
116133
117- case class VectorRDDStatisticalAggregator (
118- mean : BV [Double ],
119- statCnt : BV [Double ],
120- totalCnt : Double ,
121- nnz : BV [Double ],
122- currMax : BV [Double ],
123- currMin : BV [Double ])
124-
125134/**
126135 * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector ]] through an
127136 * implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
128137 * these functions.
129138 */
130139class VectorRDDFunctions (self : RDD [Vector ]) extends Serializable {
131140
132- /**
133- * Aggregate function used for aggregating elements in a worker together.
134- */
135- private def seqOp (
136- aggregator : VectorRDDStatisticalAggregator ,
137- currData : BV [Double ]): VectorRDDStatisticalAggregator = {
138- aggregator match {
139- case VectorRDDStatisticalAggregator (prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
140- currData.activeIterator.foreach {
141- case (id, 0.0 ) =>
142- case (id, value) =>
143- if (maxVec(id) < value) maxVec(id) = value
144- if (minVec(id) > value) minVec(id) = value
145-
146- val tmpPrevMean = prevMean(id)
147- prevMean(id) = (prevMean(id) * cnt + value) / (cnt + 1.0 )
148- prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)
149-
150- nnzVec(id) += 1.0
151- }
152-
153- VectorRDDStatisticalAggregator (
154- prevMean,
155- prevM2n,
156- cnt + 1.0 ,
157- nnzVec,
158- maxVec,
159- minVec)
160- }
161- }
162-
163- /**
164- * Combine function used for combining intermediate results together from every worker.
165- */
166- private def combOp (
167- statistics1 : VectorRDDStatisticalAggregator ,
168- statistics2 : VectorRDDStatisticalAggregator ): VectorRDDStatisticalAggregator = {
169- (statistics1, statistics2) match {
170- case (VectorRDDStatisticalAggregator (mean1, m2n1, cnt1, nnz1, max1, min1),
171- VectorRDDStatisticalAggregator (mean2, m2n2, cnt2, nnz2, max2, min2)) =>
172- val totalCnt = cnt1 + cnt2
173- val deltaMean = mean2 - mean1
174-
175- mean2.activeIterator.foreach {
176- case (id, 0.0 ) =>
177- case (id, value) =>
178- mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
179- }
180-
181- m2n2.activeIterator.foreach {
182- case (id, 0.0 ) =>
183- case (id, value) =>
184- m2n1(id) +=
185- value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+ nnz2(id))
186- }
187-
188- max2.activeIterator.foreach {
189- case (id, value) =>
190- if (max1(id) < value) max1(id) = value
191- }
192-
193- min2.activeIterator.foreach {
194- case (id, value) =>
195- if (min1(id) > value) min1(id) = value
196- }
197-
198- axpy(1.0 , nnz2, nnz1)
199- VectorRDDStatisticalAggregator (mean1, m2n1, totalCnt, nnz1, max1, min1)
200- }
201- }
202-
203141 /**
204142 * Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
205143 */
206- def summarizeStatistics (): VectorRDDStatisticalAggregator = {
144+ def summarizeStatistics (): VectorRDDStatisticalSummary = {
207145 val size = self.take(1 ).head.size
208- val zeroValue = VectorRDDStatisticalAggregator (
146+
147+ val zeroValue = new Aggregator (
209148 BV .zeros[Double ](size),
210149 BV .zeros[Double ](size),
211150 0.0 ,
212151 BV .zeros[Double ](size),
213152 BV .fill(size)(Double .MinValue ),
214153 BV .fill(size)(Double .MaxValue ))
215154
216- val VectorRDDStatisticalAggregator (currMean, currM2n, totalCnt, nnz, currMax, currMin) =
217- self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
218-
219- // solve real mean
220- val realMean = currMean :* nnz :/ totalCnt
221-
222- // solve real m2n
223- val deltaMean = currMean
224- val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
225-
226- // remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
227- nnz.activeIterator.foreach {
228- case (id, 0.0 ) =>
229- currMax(id) = 0.0
230- currMin(id) = 0.0
231- case _ =>
232- }
233-
234- // get variance
235- realM2n :/= totalCnt
236-
237- VectorRDDStatisticalAggregator (
238- realMean,
239- realM2n,
240- totalCnt,
241- nnz,
242- currMax,
243- currMin)
155+ self.map(_.toBreeze).aggregate[Aggregator ](zeroValue)(
156+ (aggregator, data) => aggregator.add(data),
157+ (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
158+ )
244159 }
245160}
0 commit comments