Skip to content

Commit 1376ff4

Browse files
committed
rename variables and adjust code
1 parent 4a5c38d commit 1376ff4

File tree

2 files changed

+54
-65
lines changed

2 files changed

+54
-65
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,20 @@ package org.apache.spark.mllib.rdd
1818

1919
import breeze.linalg.{axpy, Vector => BV}
2020

21-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
21+
import org.apache.spark.mllib.linalg.Vector
2222
import org.apache.spark.rdd.RDD
2323

2424
/**
2525
* Case class of the summary statistics, including mean, variance, count, max, min, and non-zero
2626
* elements count.
2727
*/
28-
case class VectorRDDStatisticalSummary(
29-
mean: Vector,
30-
variance: Vector,
31-
count: Long,
32-
max: Vector,
33-
min: Vector,
34-
nonZeroCnt: Vector) extends Serializable
35-
36-
/**
37-
* Case class of the aggregate value for collecting summary statistics from RDD[Vector]. These
38-
* values are related to
39-
* [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]], the
40-
* latter is computed from the former.
41-
*/
42-
private case class VectorRDDStatisticalRing(
43-
fakeMean: BV[Double],
44-
fakeM2n: BV[Double],
45-
totalCnt: Double,
46-
nnz: BV[Double],
47-
fakeMax: BV[Double],
48-
fakeMin: BV[Double])
28+
case class VectorRDDStatisticalAggregator(
29+
mean: BV[Double],
30+
statCounter: BV[Double],
31+
totalCount: Double,
32+
numNonZeros: BV[Double],
33+
max: BV[Double],
34+
min: BV[Double])
4935

5036
/**
5137
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -58,11 +44,12 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5844
* Aggregate function used for aggregating elements in a worker together.
5945
*/
6046
private def seqOp(
61-
aggregator: VectorRDDStatisticalRing,
62-
currData: BV[Double]): VectorRDDStatisticalRing = {
47+
aggregator: VectorRDDStatisticalAggregator,
48+
currData: BV[Double]): VectorRDDStatisticalAggregator = {
6349
aggregator match {
64-
case VectorRDDStatisticalRing(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
50+
case VectorRDDStatisticalAggregator(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
6551
currData.activeIterator.foreach {
52+
case (id, 0.0) =>
6653
case (id, value) =>
6754
if (maxVec(id) < value) maxVec(id) = value
6855
if (minVec(id) > value) minVec(id) = value
@@ -74,7 +61,7 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
7461
nnzVec(id) += 1.0
7562
}
7663

77-
VectorRDDStatisticalRing(
64+
VectorRDDStatisticalAggregator(
7865
prevMean,
7966
prevM2n,
8067
cnt + 1.0,
@@ -88,11 +75,11 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8875
* Combine function used for combining intermediate results together from every worker.
8976
*/
9077
private def combOp(
91-
statistics1: VectorRDDStatisticalRing,
92-
statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
78+
statistics1: VectorRDDStatisticalAggregator,
79+
statistics2: VectorRDDStatisticalAggregator): VectorRDDStatisticalAggregator = {
9380
(statistics1, statistics2) match {
94-
case (VectorRDDStatisticalRing(mean1, m2n1, cnt1, nnz1, max1, min1),
95-
VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
81+
case (VectorRDDStatisticalAggregator(mean1, m2n1, cnt1, nnz1, max1, min1),
82+
VectorRDDStatisticalAggregator(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
9683
val totalCnt = cnt1 + cnt2
9784
val deltaMean = mean2 - mean1
9885

@@ -120,51 +107,50 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
120107
}
121108

122109
axpy(1.0, nnz2, nnz1)
123-
VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1)
110+
VectorRDDStatisticalAggregator(mean1, m2n1, totalCnt, nnz1, max1, min1)
124111
}
125112
}
126113

127114
/**
128115
* Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
129116
*/
130-
def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
131-
val zeroValue = VectorRDDStatisticalRing(
117+
def summarizeStatistics(): VectorRDDStatisticalAggregator = {
118+
val size = self.take(1).head.size
119+
val zeroValue = VectorRDDStatisticalAggregator(
132120
BV.zeros[Double](size),
133121
BV.zeros[Double](size),
134122
0.0,
135123
BV.zeros[Double](size),
136124
BV.fill(size)(Double.MinValue),
137125
BV.fill(size)(Double.MaxValue))
138126

139-
val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, fakeMax, fakeMin) =
127+
val VectorRDDStatisticalAggregator(currMean, currM2n, totalCnt, nnz, currMax, currMin) =
140128
self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
141129

142130
// solve real mean
143-
val realMean = fakeMean :* nnz :/ totalCnt
131+
val realMean = currMean :* nnz :/ totalCnt
144132

145133
// solve real m2n
146-
val deltaMean = fakeMean
147-
val realM2n = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
134+
val deltaMean = currMean
135+
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
148136

149137
// remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
150-
val max = Vectors.sparse(size, fakeMax.activeIterator.map { case (id, value) =>
151-
if ((value == Double.MinValue) && (realMean(id) != Double.MinValue)) (id, 0.0)
152-
else (id, value)
153-
}.toSeq)
154-
val min = Vectors.sparse(size, fakeMin.activeIterator.map { case (id, value) =>
155-
if ((value == Double.MaxValue) && (realMean(id) != Double.MaxValue)) (id, 0.0)
156-
else (id, value)
157-
}.toSeq)
138+
nnz.activeIterator.foreach {
139+
case (id, 0.0) =>
140+
currMax(id) = 0.0
141+
currMin(id) = 0.0
142+
case _ =>
143+
}
158144

159145
// get variance
160146
realM2n :/= totalCnt
161147

162-
VectorRDDStatisticalSummary(
163-
Vectors.fromBreeze(realMean),
164-
Vectors.fromBreeze(realM2n),
165-
totalCnt.toLong,
166-
Vectors.fromBreeze(nnz),
167-
max,
168-
min)
148+
VectorRDDStatisticalAggregator(
149+
realMean,
150+
realM2n,
151+
totalCnt,
152+
nnz,
153+
currMax,
154+
currMin)
169155
}
170-
}
156+
}

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
1817
package org.apache.spark.mllib.rdd
1918

2019
import scala.collection.mutable.ArrayBuffer
@@ -45,18 +44,23 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4544

4645
test("full-statistics") {
4746
val data = sc.parallelize(localData, 2)
48-
val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) =
49-
time(data.summarizeStatistics(3))
47+
val (VectorRDDStatisticalAggregator(mean, variance, cnt, nnz, max, min), denseTime) =
48+
time(data.summarizeStatistics())
5049

51-
assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
52-
assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
53-
assert(cnt === 3, "Column cnt do not match.")
54-
assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
55-
assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
56-
assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
50+
assert(equivVector(Vectors.fromBreeze(mean), Vectors.dense(4.0, 5.0, 6.0)),
51+
"Column mean do not match.")
52+
assert(equivVector(Vectors.fromBreeze(variance), Vectors.dense(6.0, 6.0, 6.0)),
53+
"Column variance do not match.")
54+
assert(cnt === 3.0, "Column cnt do not match.")
55+
assert(equivVector(Vectors.fromBreeze(nnz), Vectors.dense(3.0, 3.0, 3.0)),
56+
"Column nnz do not match.")
57+
assert(equivVector(Vectors.fromBreeze(max), Vectors.dense(7.0, 8.0, 9.0)),
58+
"Column max do not match.")
59+
assert(equivVector(Vectors.fromBreeze(min), Vectors.dense(1.0, 2.0, 3.0)),
60+
"Column min do not match.")
5761

5862
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
59-
val (_, sparseTime) = time(dataForSparse.summarizeStatistics(20))
63+
val (_, sparseTime) = time(dataForSparse.summarizeStatistics())
6064

6165
println(s"dense time is $denseTime, sparse time is $sparseTime.")
6266
assert(relativeTime(denseTime, sparseTime),
@@ -80,5 +84,4 @@ object VectorRDDFunctionsSuite {
8084
val denominator = math.max(lhs, rhs)
8185
math.abs(lhs - rhs) / denominator < 0.3
8286
}
83-
}
84-
87+
}

0 commit comments

Comments
 (0)