
Commit 70b93e3

Author: Feynman Liang
Performance improvements in LocalPrefixSpan, fix tests
1 parent 0c5207c commit 70b93e3

3 files changed: +32, -36 lines


mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 25 additions & 23 deletions
@@ -20,6 +20,8 @@ package org.apache.spark.mllib.fpm
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 
+import scala.collection.mutable.ArrayBuffer
+
 /**
  *
  * :: Experimental ::
@@ -42,22 +44,20 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
   def run(
       minCount: Long,
       maxPatternLength: Int,
-      prefix: Array[Int],
-      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
+      prefix: ArrayBuffer[Int],
+      projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = {
     val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
     val frequentPatternAndCounts = frequentPrefixAndCounts
-      .map(x => (prefix ++ Array(x._1), x._2))
+      .map(x => ((prefix :+ x._1).toArray, x._2))
     val prefixProjectedDatabases = getPatternAndProjectedDatabase(
       prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
 
-    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
-    if (continueProcess) {
-      val nextPatterns = prefixProjectedDatabases
-        .map(x => run(minCount, maxPatternLength, x._1, x._2))
-        .reduce(_ ++ _)
-      frequentPatternAndCounts ++ nextPatterns
+    if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) {
+      frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap {
+        case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB)
+      }
     } else {
-      frequentPatternAndCounts
+      frequentPatternAndCounts.iterator
     }
   }
 
@@ -86,28 +86,30 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       sequences: Array[Array[Int]]): Array[(Int, Long)] = {
     sequences.flatMap(_.distinct)
-      .groupBy(x => x)
-      .mapValues(_.length.toLong)
+      .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
+        ctr + (item -> (ctr(item) + 1))
+      }
      .filter(_._2 >= minCount)
      .toArray
   }
 
   /**
    * Get the frequent prefixes' projected database.
-   * @param prePrefix the frequent prefixes' prefix
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
+   * @param prefix the frequent prefixes' prefix
+   * @param frequentPrefixes frequent next prefixes
+   * @param projDB projected database for given prefix
+   * @return extensions of prefix by one item and corresponding projected databases
    */
   private def getPatternAndProjectedDatabase(
-      prePrefix: Array[Int],
+      prefix: ArrayBuffer[Int],
       frequentPrefixes: Array[Int],
-      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
-    val filteredProjectedDatabase = sequences
-      .map(x => x.filter(frequentPrefixes.contains(_)))
-    frequentPrefixes.map { x =>
-      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
-      (prePrefix ++ Array(x), sub)
+      projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = {
+    val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_)))
+    frequentPrefixes.map { nextItem =>
+      val nextProjDB = filteredProjectedDatabase
+        .map(candidateSeq => getSuffix(nextItem, candidateSeq))
+        .filter(_.nonEmpty)
+      (prefix :+ nextItem, nextProjDB)
     }.filter(x => x._2.nonEmpty)
   }
 }
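
The getFreqItemAndCounts change above swaps groupBy(x => x).mapValues(_.length.toLong), which materializes an intermediate array of occurrences per item, for a single fold into an immutable Map with a default value of 0. The following is a minimal, self-contained sketch of that counting pattern only; the sample sequences and the minCount threshold are invented for illustration, not Spark data.

// Standalone sketch of the counting pattern introduced above.
// The sequences and minCount below are made-up example data.
object ItemCountingSketch {
  def main(args: Array[String]): Unit = {
    val sequences: Array[Array[Int]] = Array(Array(1, 2, 3), Array(1, 3), Array(2, 2, 3))
    val minCount = 2L

    // Count, for each item, the number of sequences containing it (hence the
    // per-sequence distinct), folding into a default-valued Map instead of
    // grouping occurrences into intermediate arrays first.
    val freqItemsAndCounts = sequences.flatMap(_.distinct)
      .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
        ctr + (item -> (ctr(item) + 1))
      }
      .filter(_._2 >= minCount)
      .toArray

    // Items 1 and 2 appear in two sequences each, item 3 in all three.
    println(freqItemsAndCounts.sortBy(_._1).mkString(", "))
  }
}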

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 4 additions & 2 deletions
@@ -22,6 +22,8 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
+import scala.collection.mutable.ArrayBuffer
+
 /**
  *
  * :: Experimental ::
@@ -150,8 +152,8 @@ class PrefixSpan private (
   private def getPatternsInLocal(
       minCount: Long,
       data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { x =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
+    data.flatMap { case (prefix, projDB) =>
+      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB)
     }
   }
 }
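
With this change getPatternsInLocal destructures each (prefix, projected database) pair and hands LocalPrefixSpan.run an ArrayBuffer prefix, letting flatMap drain the Iterator that run now returns instead of concatenating per-prefix Arrays. The sketch below is a rough, self-contained analogy for why the Iterator-returning recursion helps; the two-way branching and the depths are invented stand-ins, not the PrefixSpan algorithm or its projected databases.

import scala.collection.mutable.ArrayBuffer

object IteratorRecursionSketch {
  // Array-returning recursion (old shape): every level concatenates and copies arrays.
  def enumerateArrays(prefix: Array[Int], depth: Int): Array[Array[Int]] =
    if (depth == 0) Array(prefix)
    else (0 until 2).map(i => enumerateArrays(prefix :+ i, depth - 1)).reduce(_ ++ _)

  // Iterator-returning recursion (new shape): results are produced lazily, so a
  // downstream flatMap never materializes the intermediate per-level collections.
  def enumerateIter(prefix: ArrayBuffer[Int], depth: Int): Iterator[Array[Int]] =
    if (depth == 0) Iterator(prefix.toArray)
    else (0 until 2).iterator.flatMap(i => enumerateIter(prefix :+ i, depth - 1))

  def main(args: Array[String]): Unit = {
    // Mirrors the shape of getPatternsInLocal: destructure the pair, convert the
    // prefix, and let flatMap consume the iterator. The .to[ArrayBuffer] spelling
    // is the Scala 2.x conversion used in the diff; Scala 2.13+ writes .to(ArrayBuffer).
    val data = Seq((Array(0), 2), (Array(1), 3))
    val patterns = data.flatMap { case (prefix, depth) =>
      enumerateIter(prefix.to[ArrayBuffer], depth)
    }
    println(patterns.length)                      // 4 + 8 = 12 enumerated patterns
    println(enumerateArrays(Array(0), 2).length)  // 4, the same count computed eagerly
  }
}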

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 3 additions & 11 deletions
@@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
 
-class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("PrefixSpan using Integer type") {
 
@@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
     def compareResult(
         expectedValue: Array[(Array[Int], Long)],
         actualValue: Array[(Array[Int], Long)]): Boolean = {
-      val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
-        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
-      }
-      val sortedActualValue = actualValue.sortWith{ (x, y) =>
-        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
-      }
-      sortedExpectedValue.zip(sortedActualValue)
-        .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
-        .reduce(_&&_)
+      expectedValue.map(x => (x._1.toList, x._2)).toSet ==
+        actualValue.map(x => (x._1.toList, x._2)).toSet
     }
 
     val prefixspan = new PrefixSpan()
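
The rewritten compareResult above compares the two result collections as sets of (List[Int], Long) pairs, which makes the check order-insensitive and sidesteps Array's reference-based equality. Below is a standalone sketch of that same helper with invented expected and actual values.

object CompareResultSketch {
  // Same comparison as the updated test helper: convert each Array pattern to a
  // List (Arrays compare by reference) and compare the two collections as sets.
  def compareResult(
      expectedValue: Array[(Array[Int], Long)],
      actualValue: Array[(Array[Int], Long)]): Boolean = {
    expectedValue.map(x => (x._1.toList, x._2)).toSet ==
      actualValue.map(x => (x._1.toList, x._2)).toSet
  }

  def main(args: Array[String]): Unit = {
    val expected = Array((Array(1, 2), 2L), (Array(1), 3L))
    val actual = Array((Array(1), 3L), (Array(1, 2), 2L))  // same patterns, different order
    println(compareResult(expected, actual))               // true
  }
}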
