@@ -65,6 +65,13 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
6565 >>> model.userFeatures().collect()
6666 [(1, array('d', [...])), (2, array('d', [...]))]
6767
68+ >>> model.recommendUsers(1, 2)
69+ [Rating(user=2, product=1, rating=1.9...), Rating(user=1, product=1, rating=1.0...)]
70+ >>> model.recommendProducts(1, 2)
71+ [Rating(user=1, product=2, rating=1.9...), Rating(user=1, product=1, rating=1.0...)]
72+ >>> model.rank
73+ 4
74+
6875 >>> first_user = model.userFeatures().take(1)[0]
6976 >>> latents = first_user[1]
7077 >>> len(latents) == 4
@@ -105,21 +112,53 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
105112 ... pass
106113 """
107114 def predict (self , user , product ):
115+ """
116+ Predicts rating for the given user and product.
117+ """
108118 return self ._java_model .predict (int (user ), int (product ))
109119
110120 def predictAll (self , user_product ):
121+ """
122+ Returns a list of predicted ratings for input user and product pairs.
123+ """
111124 assert isinstance (user_product , RDD ), "user_product should be RDD of (user, product)"
112125 first = user_product .first ()
113126 assert len (first ) == 2 , "user_product should be RDD of (user, product)"
114127 user_product = user_product .map (lambda u_p : (int (u_p [0 ]), int (u_p [1 ])))
115128 return self .call ("predict" , user_product )
116129
117130 def userFeatures (self ):
131+ """
132+ Returns a paired RDD, where the first element is the user and the
133+ second is an array of features corresponding to that user.
134+ """
118135 return self .call ("getUserFeatures" ).mapValues (lambda v : array .array ('d' , v ))
119136
120137 def productFeatures (self ):
138+ """
139+ Returns a paired RDD, where the first element is the product and the
140+ second is an array of features corresponding to that product.
141+ """
121142 return self .call ("getProductFeatures" ).mapValues (lambda v : array .array ('d' , v ))
122143
144+ def recommendUsers (self , product , num ):
145+ """
146+ Recommends the top "num" number of users for a given product and returns a list
147+ of Rating objects sorted by the predicted rating in descending order.
148+ """
149+ return list (self .call ("recommendUsers" , product , num ))
150+
151+ def recommendProducts (self , user , num ):
152+ """
153+ Recommends the top "num" number of products for a given user and returns a list
154+ of Rating objects sorted by the predicted rating in descending order.
155+ """
156+ return list (self .call ("recommendProducts" , user , num ))
157+
158+ @property
159+ def rank (self ):
160+ return self .call ("rank" )
161+
123162 @classmethod
124163 def load (cls , sc , path ):
125164 model = cls ._load_java (sc , path )
0 commit comments