11 changes: 11 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params {
/** @group getParam */
def getOutputCol: String = get(outputCol)
}

private[ml] trait HasCheckpointInterval extends Params {
/**
* param for checkpoint interval
* @group param
*/
val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")

/** @group getParam */
def getCheckpointInterval: Int = get(checkpointInterval)
}
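
A minimal sketch (not part of this pull request) of how another component could reuse the new shared param, assuming the 1.3-era Params API (protected set / get) used throughout this file; MyCheckpointedParams is a hypothetical name used only for illustration:

// Hypothetical trait, for illustration only. It must live under org.apache.spark.ml
// because HasCheckpointInterval is private[ml].
private[ml] trait MyCheckpointedParams extends Params with HasCheckpointInterval {
  /** @group setParam */
  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
}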
42 changes: 37 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.recommendation

import java.{util => ju}
import java.io.IOException

import scala.collection.mutable
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.netlib.util.intW

import org.apache.spark.{Logging, Partitioner}
@@ -46,7 +48,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
with HasPredictionCol {
with HasPredictionCol with HasCheckpointInterval {

/**
* Param for rank of the matrix factorization.
@@ -164,6 +166,7 @@ class ALSModel private[ml] (
itemFactors: RDD[(Int, Array[Float])])
extends Model[ALSModel] with ALSParams {

/** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
@@ -262,6 +265,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)

/** @group setParam */
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -274,6 +280,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {

setMaxIter(20)
setRegParam(1.0)
setCheckpointInterval(10)

override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
@@ -285,7 +292,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
alpha = map(alpha), nonnegative = map(nonnegative))
alpha = map(alpha), nonnegative = map(nonnegative),
checkpointInterval = map(checkpointInterval))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -494,6 +502,7 @@ object ALS extends Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -521,26 +530,49 @@
val seedGen = new XORShiftRandom(seed)
var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
var previousCheckpointFile: Option[String] = None
val shouldCheckpoint: Int => Boolean = (iter) =>
sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
val deletePreviousCheckpointFile: () => Unit = () =>
previousCheckpointFile.foreach { file =>
try {
FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
} catch {
case e: IOException =>
logWarning(s"Cannot delete checkpoint file $file:", e)
}
}
if (implicitPrefs) {
for (iter <- 1 to maxIter) {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
itemFactors.checkpoint()
}
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
// TODO: Generalize PeriodicGraphCheckpointer and use it here.
if (shouldCheckpoint(iter)) {
itemFactors.checkpoint() // itemFactors gets materialized in computeFactors.
}
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
if (shouldCheckpoint(iter)) {
deletePreviousCheckpointFile()
previousCheckpointFile = itemFactors.getCheckpointFile
}
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
if (shouldCheckpoint(iter)) {
itemFactors.checkpoint()
itemFactors.count() // checkpoint item factors and cut lineage
deletePreviousCheckpointFile()
previousCheckpointFile = itemFactors.getCheckpointFile
}
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
}
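A short usage sketch (not part of the diff) of how the new Pipelines-API param would be exercised; it assumes a SparkContext sc, a ratings DataFrame with the default user/item/rating columns, and the fit overload that takes only a dataset:

// Checkpointing only kicks in when a checkpoint directory is set on the SparkContext.
sc.setCheckpointDir("/tmp/als-checkpoints")

val als = new org.apache.spark.ml.recommendation.ALS()
  .setRank(10)
  .setMaxIter(50)             // long runs are where lineage truncation matters
  .setCheckpointInterval(5)   // checkpoint the factor RDDs every 5 iterations (default 10)
val model = als.fit(ratings)  // stale checkpoint files are deleted as training proceeds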
mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -82,6 +82,9 @@ class ALS private (
private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK

/** checkpoint interval */
private var checkpointInterval: Int = 10

/**
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
* into; pass -1 for an auto-configured number of blocks. Default: -1.
@@ -182,6 +185,19 @@ class ALS private (
this
}

/**
* Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with
* recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps
* with eliminating temporary shuffle files on disk, which can be important when there are many
* ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]],
* this setting is ignored.
*/
@DeveloperApi
def setCheckpointInterval(checkpointInterval: Int): this.type = {
this.checkpointInterval = checkpointInterval
this
}

/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -212,6 +228,7 @@ class ALS private (
nonnegative = nonnegative,
intermediateRDDStorageLevel = intermediateRDDStorageLevel,
finalRDDStorageLevel = StorageLevel.NONE,
checkpointInterval = checkpointInterval,
seed = seed)

val userFactors = floatUserFactors
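The same knob on the RDD-based API, shown as a usage sketch; it assumes a SparkContext sc and an RDD[Rating] named ratings, and uses only setters that already exist on this class:

import org.apache.spark.mllib.recommendation.{ALS, Rating}

sc.setCheckpointDir("/tmp/als-checkpoints")  // without this, checkpointInterval is ignored

val model = new ALS()
  .setRank(10)
  .setIterations(50)
  .setCheckpointInterval(5)   // checkpoint every 5 iterations instead of the default 10
  .run(ratings)               // ratings: RDD[Rating]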
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.ml.recommendation

import java.io.File
import java.util.Random

import scala.collection.mutable
@@ -32,16 +33,25 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.Utils

class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {

private var sqlContext: SQLContext = _
private var tempDir: File = _

override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Utils.createTempDir()
sc.setCheckpointDir(tempDir.getAbsolutePath)
sqlContext = new SQLContext(sc)
}

override def afterAll(): Unit = {
Utils.deleteRecursively(tempDir)
super.afterAll()
}

test("LocalIndexEncoder") {
val random = new Random
for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
@@ -485,4 +495,11 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
}.count()
}
}

test("als with large number of iterations") {
val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2)
ALS.train(
ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true)
}
}