@@ -46,57 +46,49 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
4646 * correlation between column i and j.
4747 */
override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
  // Upper bound on the number of row ids buffered for a single run of tied values.
  // When a run grows past this bound it is flushed early to keep memory bounded, so the
  // average rank is then computed per flushed chunk rather than over the whole run.
  val maxCachedUids = 10000000

  // ((columnIndex, value), rowUid)
  val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
    vec.toArray.view.zipWithIndex.map { case (v, j) =>
      ((j, v), uid)
    }
  }
  // global sort by (columnIndex, value)
  val sorted = colBased.sortByKey()
  // assign global ranks (using average ranks for tied values)
  val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
    // Running state for the current run of equal (column, value) entries.
    // NOTE(review): this state is local to each partition, so a run of tied values that is
    // split across a partition boundary is averaged separately on each side.
    var preCol = -1
    // NaN never compares equal to anything, so the first element always starts a new run.
    var preVal = Double.NaN
    var startRank = -1.0
    // The buffer is mutated in place and never reassigned, hence a `val`.
    val cachedUids = ArrayBuffer.empty[Long]

    // Emits (rowUid, (columnIndex, averageRank)) for every buffered row id and resets the
    // buffer. The ranks of a run are startRank, startRank + 1, ..., so their mean is
    // startRank + (size - 1) / 2. Returns an empty collection when nothing is buffered.
    def flush(): Iterable[(Long, (Int, Double))] = {
      val averageRank = startRank + (cachedUids.size - 1) / 2.0
      val output = cachedUids.map { uid =>
        (uid, (preCol, averageRank))
      }
      cachedUids.clear()
      output
    }

    iter.flatMap { case (((j, v), uid), rank) =>
      // If we see a new value or cachedUids is too big, we flush ids with their average rank.
      if (j != preCol || v != preVal || cachedUids.size >= maxCachedUids) {
        val output = flush()
        preCol = j
        preVal = v
        startRank = rank
        cachedUids += uid
        output
      } else {
        cachedUids += uid
        Iterator.empty
      }
      // The right-hand side of `++` on an Iterator is evaluated lazily, so this final flush
      // runs only after the iterator above has been fully consumed and emits the last run.
    } ++ flush()
  }
  // Replace values in the input matrix by their ranks compared with values in the same column.
  // Note that shifting all ranks in a column by a constant value doesn't affect result.
  val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
    // sort by column index and then convert values to a vector
    Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
  }
  // Spearman correlation is the Pearson correlation of the rank-transformed columns.
  PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
}
10194}
102-
0 commit comments