1717
1818package org .apache .spark .mllib .fpm
1919
20+ import scala .collection .mutable
21+
2022import org .apache .spark .Logging
21- import org .apache .spark .annotation .Experimental
2223
2324/**
24- *
25- * :: Experimental ::
26- *
2725 * Calculate all patterns of a projected database in local.
2826 */
29- @ Experimental
3027private [fpm] object LocalPrefixSpan extends Logging with Serializable {
3128
3229 /**
@@ -43,18 +40,18 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
4340 minCount : Long ,
4441 maxPatternLength : Int ,
4542 prefix : List [Int ],
46- database : Iterable [Array [Int ]]): Iterator [(Array [Int ], Long )] = {
43+ database : Array [Array [Int ]]): Iterator [(List [Int ], Long )] = {
4744
4845 if (database.isEmpty) return Iterator .empty
4946
5047 val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
5148 val frequentItems = frequentItemAndCounts.map(_._1).toSet
5249 val frequentPatternAndCounts = frequentItemAndCounts
53- .map { case (item, count) => ((item :: prefix).reverse.toArray , count) }
50+ .map { case (item, count) => ((item :: prefix), count) }
5451
55- val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
5652
5753 if (prefix.length + 1 < maxPatternLength) {
54+ val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
5855 frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item =>
5956 val nextProjected = project(filteredProjectedDatabase, item)
6057 run(minCount, maxPatternLength, item :: prefix, nextProjected)
@@ -79,7 +76,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
7976 }
8077 }
8178
82- def project (database : Iterable [Array [Int ]], prefix : Int ): Iterable [Array [Int ]] = {
79+ def project (database : Array [Array [Int ]], prefix : Int ): Array [Array [Int ]] = {
8380 database
8481 .map(candidateSeq => getSuffix(prefix, candidateSeq))
8582 .filter(_.nonEmpty)
@@ -93,10 +90,11 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
9390 */
9491 private def getFreqItemAndCounts (
9592 minCount : Long ,
96- database : Iterable [Array [Int ]]): Iterable [(Int , Long )] = {
93+ database : Array [Array [Int ]]): Iterable [(Int , Long )] = {
9794 database.flatMap(_.distinct)
98- .foldRight(Map [Int , Long ]().withDefaultValue(0L )) { case (item, ctr) =>
99- ctr + (item -> (ctr(item) + 1 ))
95+ .foldRight(mutable.Map [Int , Long ]().withDefaultValue(0L )) { case (item, ctr) =>
96+ ctr(item) += 1
97+ ctr
10098 }
10199 .filter(_._2 >= minCount)
102100 }
0 commit comments