Skip to content

Commit 967d041

Browse files
committed
full revision with Aggregator class
1 parent 138300c commit 967d041

File tree

2 files changed

+42
-127
lines changed

2 files changed

+42
-127
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 35 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,12 @@ private class Aggregator(
4040
var totalCnt: Double,
4141
val nnz: BV[Double],
4242
val currMax: BV[Double],
43-
val currMin: BV[Double]) extends VectorRDDStatisticalSummary {
44-
nnz.activeIterator.foreach {
45-
case (id, 0.0) =>
46-
currMax(id) = 0.0
47-
currMin(id) = 0.0
48-
case _ =>
43+
val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
44+
45+
override def mean(): Vector = {
46+
Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
4947
}
50-
override def mean(): Vector = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
48+
5149
override def variance(): Vector = {
5250
val deltaMean = currMean
5351
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
@@ -58,8 +56,23 @@ private class Aggregator(
5856
override def totalCount(): Long = totalCnt.toLong
5957

6058
override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
61-
override def max(): Vector = Vectors.fromBreeze(currMax)
62-
override def min(): Vector = Vectors.fromBreeze(currMin)
59+
60+
override def max(): Vector = {
61+
nnz.activeIterator.foreach {
62+
case (id, 0.0) => currMax(id) = 0.0
63+
case _ =>
64+
}
65+
Vectors.fromBreeze(currMax)
66+
}
67+
68+
override def min(): Vector = {
69+
nnz.activeIterator.foreach {
70+
case (id, 0.0) => currMin(id) = 0.0
71+
case _ =>
72+
}
73+
Vectors.fromBreeze(currMin)
74+
}
75+
6376
/**
6477
* Aggregate function used for aggregating elements in a worker together.
6578
*/
@@ -75,15 +88,19 @@ private class Aggregator(
7588
currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
7689

7790
nnz(id) += 1.0
78-
totalCnt += 1.0
7991
}
92+
93+
totalCnt += 1.0
8094
this
8195
}
96+
8297
/**
8398
* Combine function used for combining intermediate results together from every worker.
8499
*/
85-
def merge(other: this.type): this.type = {
100+
def merge(other: Aggregator): this.type = {
101+
86102
totalCnt += other.totalCnt
103+
87104
val deltaMean = currMean - other.currMean
88105

89106
other.currMean.activeIterator.foreach {
@@ -114,132 +131,30 @@ private class Aggregator(
114131
}
115132
}
116133

117-
case class VectorRDDStatisticalAggregator(
118-
mean: BV[Double],
119-
statCnt: BV[Double],
120-
totalCnt: Double,
121-
nnz: BV[Double],
122-
currMax: BV[Double],
123-
currMin: BV[Double])
124-
125134
/**
126135
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
127136
* implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
128137
* these functions.
129138
*/
130139
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
131140

132-
/**
133-
* Aggregate function used for aggregating elements in a worker together.
134-
*/
135-
private def seqOp(
136-
aggregator: VectorRDDStatisticalAggregator,
137-
currData: BV[Double]): VectorRDDStatisticalAggregator = {
138-
aggregator match {
139-
case VectorRDDStatisticalAggregator(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
140-
currData.activeIterator.foreach {
141-
case (id, 0.0) =>
142-
case (id, value) =>
143-
if (maxVec(id) < value) maxVec(id) = value
144-
if (minVec(id) > value) minVec(id) = value
145-
146-
val tmpPrevMean = prevMean(id)
147-
prevMean(id) = (prevMean(id) * cnt + value) / (cnt + 1.0)
148-
prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)
149-
150-
nnzVec(id) += 1.0
151-
}
152-
153-
VectorRDDStatisticalAggregator(
154-
prevMean,
155-
prevM2n,
156-
cnt + 1.0,
157-
nnzVec,
158-
maxVec,
159-
minVec)
160-
}
161-
}
162-
163-
/**
164-
* Combine function used for combining intermediate results together from every worker.
165-
*/
166-
private def combOp(
167-
statistics1: VectorRDDStatisticalAggregator,
168-
statistics2: VectorRDDStatisticalAggregator): VectorRDDStatisticalAggregator = {
169-
(statistics1, statistics2) match {
170-
case (VectorRDDStatisticalAggregator(mean1, m2n1, cnt1, nnz1, max1, min1),
171-
VectorRDDStatisticalAggregator(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
172-
val totalCnt = cnt1 + cnt2
173-
val deltaMean = mean2 - mean1
174-
175-
mean2.activeIterator.foreach {
176-
case (id, 0.0) =>
177-
case (id, value) =>
178-
mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
179-
}
180-
181-
m2n2.activeIterator.foreach {
182-
case (id, 0.0) =>
183-
case (id, value) =>
184-
m2n1(id) +=
185-
value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id))
186-
}
187-
188-
max2.activeIterator.foreach {
189-
case (id, value) =>
190-
if (max1(id) < value) max1(id) = value
191-
}
192-
193-
min2.activeIterator.foreach {
194-
case (id, value) =>
195-
if (min1(id) > value) min1(id) = value
196-
}
197-
198-
axpy(1.0, nnz2, nnz1)
199-
VectorRDDStatisticalAggregator(mean1, m2n1, totalCnt, nnz1, max1, min1)
200-
}
201-
}
202-
203141
/**
204142
* Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
205143
*/
206-
def summarizeStatistics(): VectorRDDStatisticalAggregator = {
144+
def summarizeStatistics(): VectorRDDStatisticalSummary = {
207145
val size = self.take(1).head.size
208-
val zeroValue = VectorRDDStatisticalAggregator(
146+
147+
val zeroValue = new Aggregator(
209148
BV.zeros[Double](size),
210149
BV.zeros[Double](size),
211150
0.0,
212151
BV.zeros[Double](size),
213152
BV.fill(size)(Double.MinValue),
214153
BV.fill(size)(Double.MaxValue))
215154

216-
val VectorRDDStatisticalAggregator(currMean, currM2n, totalCnt, nnz, currMax, currMin) =
217-
self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
218-
219-
// solve real mean
220-
val realMean = currMean :* nnz :/ totalCnt
221-
222-
// solve real m2n
223-
val deltaMean = currMean
224-
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
225-
226-
// remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
227-
nnz.activeIterator.foreach {
228-
case (id, 0.0) =>
229-
currMax(id) = 0.0
230-
currMin(id) = 0.0
231-
case _ =>
232-
}
233-
234-
// get variance
235-
realM2n :/= totalCnt
236-
237-
VectorRDDStatisticalAggregator(
238-
realMean,
239-
realM2n,
240-
totalCnt,
241-
nnz,
242-
currMax,
243-
currMin)
155+
self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
156+
(aggregator, data) => aggregator.add(data),
157+
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
158+
)
244159
}
245160
}

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,19 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4444

4545
test("full-statistics") {
4646
val data = sc.parallelize(localData, 2)
47-
val (VectorRDDStatisticalAggregator(mean, variance, cnt, nnz, max, min), denseTime) =
47+
val (summary, denseTime) =
4848
time(data.summarizeStatistics())
4949

50-
assert(equivVector(Vectors.fromBreeze(mean), Vectors.dense(4.0, 5.0, 6.0)),
50+
assert(equivVector(summary.mean(), Vectors.dense(4.0, 5.0, 6.0)),
5151
"Column mean do not match.")
52-
assert(equivVector(Vectors.fromBreeze(variance), Vectors.dense(6.0, 6.0, 6.0)),
52+
assert(equivVector(summary.variance(), Vectors.dense(6.0, 6.0, 6.0)),
5353
"Column variance do not match.")
54-
assert(cnt === 3.0, "Column cnt do not match.")
55-
assert(equivVector(Vectors.fromBreeze(nnz), Vectors.dense(3.0, 3.0, 3.0)),
54+
assert(summary.totalCount() === 3, "Column cnt do not match.")
55+
assert(equivVector(summary.numNonZeros(), Vectors.dense(3.0, 3.0, 3.0)),
5656
"Column nnz do not match.")
57-
assert(equivVector(Vectors.fromBreeze(max), Vectors.dense(7.0, 8.0, 9.0)),
57+
assert(equivVector(summary.max(), Vectors.dense(7.0, 8.0, 9.0)),
5858
"Column max do not match.")
59-
assert(equivVector(Vectors.fromBreeze(min), Vectors.dense(1.0, 2.0, 3.0)),
59+
assert(equivVector(summary.min(), Vectors.dense(1.0, 2.0, 3.0)),
6060
"Column min do not match.")
6161

6262
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)

0 commit comments

Comments (0)