Skip to content

Commit c627de3

Browse files
committed
add tests for implicit feedback
1 parent b84f41c commit c627de3

File tree

1 file changed

+106
-16
lines changed

1 file changed

+106
-16
lines changed

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 106 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
215215
}
216216

217217
/**
218-
* Generates ratings for testing ALS.
218+
* Generates an explicit feedback dataset for testing ALS.
219219
*
220220
* @param numUsers number of users
221221
* @param numItems number of items
@@ -226,7 +226,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
226226
* @param seed random seed
227227
* @return (training, test)
228228
*/
229-
def genALSTestData(
229+
def genExplicitTestData(
230230
numUsers: Int,
231231
numItems: Int,
232232
rank: Int,
@@ -242,9 +242,9 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
242242
val training = ArrayBuffer.empty[Rating]
243243
val test = ArrayBuffer.empty[Rating]
244244
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
245+
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
245246
val x = random.nextDouble()
246247
if (x < totalFraction) {
247-
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
248248
if (x < trainingFraction) {
249249
val noise = noiseLevel * random.nextGaussian()
250250
training += Rating(userId, itemId, rating + noise.toFloat)
@@ -257,13 +257,75 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
257257
(sc.parallelize(training, 2), sc.parallelize(test, 2))
258258
}
259259

260-
private def genFactors(size: Int, rank: Int, random: Random): Seq[(Int, Array[Float])] = {
260+
/**
261+
* Generates an implicit feedback dataset for testing ALS.
262+
* @param numUsers number of users
263+
* @param numItems number of items
264+
* @param rank rank
265+
* @param noiseLevel standard deviation of Gaussian noise on training data
266+
* @param seed random seed
267+
* @return (training, test)
268+
*/
269+
def genImplicitTestData(
270+
numUsers: Int,
271+
numItems: Int,
272+
rank: Int,
273+
noiseLevel: Double = 0.0,
274+
seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
275+
// The assumption of the implicit feedback model is that unobserved ratings are more likely to
276+
// be negatives.
277+
val positiveFraction = 0.8
278+
val negativeFraction = 1.0 - positiveFraction
279+
val trainingFraction = 0.6
280+
val testFraction = 0.3
281+
val totalFraction = trainingFraction + testFraction
282+
val random = new Random(seed)
283+
val userFactors = genFactors(numUsers, rank, random)
284+
val itemFactors = genFactors(numItems, rank, random)
285+
val training = ArrayBuffer.empty[Rating]
286+
val test = ArrayBuffer.empty[Rating]
287+
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
288+
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
289+
val threshold = if (rating > 0) positiveFraction else negativeFraction
290+
val observed = random.nextDouble() < threshold
291+
if (observed) {
292+
val x = random.nextDouble()
293+
if (x < trainingFraction) {
294+
val noise = noiseLevel * random.nextGaussian()
295+
training += Rating(userId, itemId, rating + noise.toFloat)
296+
} else if (x < totalFraction) {
297+
test += Rating(userId, itemId, rating)
298+
}
299+
}
300+
}
301+
logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " +
302+
s"and ${test.size} for test.")
303+
(sc.parallelize(training, 2), sc.parallelize(test, 2))
304+
}
305+
306+
/**
307+
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
308+
* @param size number of users/items
309+
* @param rank number of features
310+
* @param random random number generator
311+
* @param a min value of the support (default: -1)
312+
* @param b max value of the support (default: 1)
313+
* @return a sequence of (ID, factors) pairs
314+
*/
315+
private def genFactors(
316+
size: Int,
317+
rank: Int,
318+
random: Random,
319+
a: Float = -1.0f,
320+
b: Float = 1.0f): Seq[(Int, Array[Float])] = {
261321
require(size > 0 && size < Int.MaxValue / 3)
322+
require(b > a)
262323
val ids = mutable.Set.empty[Int]
263324
while (ids.size < size) {
264325
ids += random.nextInt()
265326
}
266-
ids.toSeq.sorted.map(id => (id, Array.fill(rank)(random.nextFloat())))
327+
val width = b - a
328+
ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width)))
267329
}
268330

269331
def testALS(
@@ -272,6 +334,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
272334
rank: Int,
273335
maxIter: Int,
274336
regParam: Double,
337+
implicitPrefs: Boolean = false,
338+
alpha: Double = 1.0,
275339
targetRMSE: Double,
276340
numUserBlocks: Int = 2,
277341
numItemBlocks: Int = 3): Unit = {
@@ -280,43 +344,62 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
280344
val als = new ALS()
281345
.setRank(rank)
282346
.setRegParam(regParam)
347+
.setImplicitPrefs(implicitPrefs)
348+
.setAlpha(alpha)
283349
.setNumUserBlocks(numUserBlocks)
284350
.setNumItemBlocks(numItemBlocks)
285351
val model = als.fit(training)
286-
val prediction = model.transform(test)
287-
val mse = prediction.select('rating, 'prediction)
352+
val predictions = model.transform(test)
353+
.select('rating, 'prediction)
288354
.map { case Row(rating: Float, prediction: Float) =>
289-
val err = rating.toDouble - prediction
290-
err * err
291-
}.mean()
292-
val rmse = math.sqrt(mse)
355+
(rating.toDouble, prediction.toDouble)
356+
}
357+
val rmse =
358+
if (implicitPrefs) {
359+
val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
360+
val confidence = 1.0 + alpha * math.abs(rating)
361+
val rating01 = math.max(math.min(rating, 1.0), 0.0)
362+
val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
363+
val err = prediction01 - rating01
364+
(confidence, confidence * err * err)
365+
}.reduce { case ((c0, e0), (c1, e1)) =>
366+
(c0 + c1, e0 + e1)
367+
}
368+
math.sqrt(weightedSumSq / totalWeight)
369+
} else {
370+
val mse = predictions.map { case (rating, prediction) =>
371+
val err = rating - prediction
372+
err * err
373+
}.mean()
374+
math.sqrt(mse)
375+
}
293376
logInfo(s"Test RMSE is $rmse.")
294377
assert(rmse < targetRMSE)
295378
}
296379

297380
test("exact rank-1 matrix") {
298-
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 1,
381+
val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1,
299382
trainingFraction = 0.6, testFraction = 0.3)
300383
testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001)
301384
testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001)
302385
}
303386

304387
test("approximate rank-1 matrix") {
305-
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 1,
388+
val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1,
306389
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
307390
testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
308391
testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
309392
}
310393

311394
test("approximate rank-2 matrix") {
312-
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 2,
395+
val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 2,
313396
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
314397
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
315398
testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
316399
}
317400

318401
test("different block settings") {
319-
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 2,
402+
val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 2,
320403
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
321404
for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
322405
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
@@ -325,9 +408,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
325408
}
326409

327410
test("more blocks than ratings") {
328-
val (training, test) = genALSTestData(numUsers = 4, numItems = 4, rank = 1,
411+
val (training, test) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1,
329412
trainingFraction = 0.7, testFraction = 0.3)
330413
testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
331414
numItemBlocks = 5, numUserBlocks = 5)
332415
}
416+
417+
test("implicit feedback") {
418+
val (training, test) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2,
419+
noiseLevel = 0.01)
420+
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true,
421+
targetRMSE = 0.3)
422+
}
333423
}

0 commit comments

Comments
 (0)