Skip to content

Commit 54c7b85

Browse files
MechCoder authored and jeanlyn committed
[SPARK-6257] [PYSPARK] [MLLIB] MLlib API missing items in Recommendation
Adds rank, recommendUsers and recommendProducts to MatrixFactorizationModel in PySpark.

Author: MechCoder <[email protected]>

Closes apache#5807 from MechCoder/spark-6257 and squashes the following commits:

09629c6 [MechCoder] doc
953b326 [MechCoder] [SPARK-6257] MLlib API missing items in Recommendation
1 parent 22d9b93 commit 54c7b85

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

docs/mllib-collaborative-filtering.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ model = ALS.train(ratings, rank, numIterations)
216216
testdata = ratings.map(lambda p: (p[0], p[1]))
217217
predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
218218
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
219-
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
219+
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean()
220220
print("Mean Squared Error = " + str(MSE))
221221

222222
# Save and load model

python/pyspark/mllib/recommendation.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,13 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
6565
>>> model.userFeatures().collect()
6666
[(1, array('d', [...])), (2, array('d', [...]))]
6767
68+
>>> model.recommendUsers(1, 2)
69+
[Rating(user=2, product=1, rating=1.9...), Rating(user=1, product=1, rating=1.0...)]
70+
>>> model.recommendProducts(1, 2)
71+
[Rating(user=1, product=2, rating=1.9...), Rating(user=1, product=1, rating=1.0...)]
72+
>>> model.rank
73+
4
74+
6875
>>> first_user = model.userFeatures().take(1)[0]
6976
>>> latents = first_user[1]
7077
>>> len(latents) == 4
@@ -105,21 +112,53 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
105112
... pass
106113
"""
107114
def predict(self, user, product):
    """
    Predict the rating of a single (user, product) pair.

    Both arguments are coerced with int() and handed to the wrapped
    Java model's ``predict`` method; its result is returned unchanged.
    """
    java_predict = self._java_model.predict
    user_id = int(user)
    product_id = int(product)
    return java_predict(user_id, product_id)
109119

110120
def predictAll(self, user_product):
    """
    Predicts ratings for an RDD of (user, product) pairs.

    :param user_product: RDD whose elements are 2-item (user, product)
        pairs. Each item is coerced with int() before prediction.
    :return: whatever the model's "predict" call yields for the coerced
        RDD. This is a distributed result that callers chain ``.map()``
        on, not a Python list -- presumably an RDD of Rating objects;
        confirm against the JVM wrapper.
    :raises AssertionError: if ``user_product`` is not an RDD, or if its
        first element does not have exactly two items.
    """
    # NOTE: validation is done with assert, so it is skipped when Python
    # runs with -O. Kept as-is so callers see the same exception type.
    assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
    # Only the first element is inspected; later elements are not checked.
    first = user_product.first()
    assert len(first) == 2, "user_product should be RDD of (user, product)"
    user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1])))
    return self.call("predict", user_product)
116129

117130
def userFeatures(self):
    """
    Returns a paired RDD keyed by user, where each value is that user's
    feature vector converted to an ``array.array`` of doubles ('d').
    """
    def as_double_array(values):
        return array.array('d', values)

    raw_features = self.call("getUserFeatures")
    return raw_features.mapValues(as_double_array)
119136

120137
def productFeatures(self):
    """
    Returns a paired RDD keyed by product, where each value is that
    product's feature vector converted to an ``array.array`` of
    doubles ('d').
    """
    def as_double_array(values):
        return array.array('d', values)

    raw_features = self.call("getProductFeatures")
    return raw_features.mapValues(as_double_array)
122143

144+
def recommendUsers(self, product, num):
    """
    Recommends the top ``num`` users for a given product.

    :param product: product identifier; coerced with int().
    :param num: how many users to recommend; coerced with int().
    :return: a list built from the model's "recommendUsers" result.
        Per the class doctest these are Rating objects ordered by
        predicted rating, highest first.
    """
    # Coerce the same way predict() does, so integer-like inputs are
    # forwarded to the underlying call as plain Python ints.
    return list(self.call("recommendUsers", int(product), int(num)))
150+
151+
def recommendProducts(self, user, num):
    """
    Recommends the top ``num`` products for a given user.

    :param user: user identifier; coerced with int().
    :param num: how many products to recommend; coerced with int().
    :return: a list built from the model's "recommendProducts" result.
        Per the class doctest these are Rating objects ordered by
        predicted rating, highest first.
    """
    # Coerce the same way predict() does, so integer-like inputs are
    # forwarded to the underlying call as plain Python ints.
    return list(self.call("recommendProducts", int(user), int(num)))
157+
158+
@property
def rank(self):
    """
    The model's rank, as reported by the wrapped model's "rank" call.
    """
    model_rank = self.call("rank")
    return model_rank
161+
123162
@classmethod
124163
def load(cls, sc, path):
125164
model = cls._load_java(sc, path)

0 commit comments

Comments
 (0)