
Commit 34be542

brkyvz authored and jeanlyn committed
[SPARK-7488] [ML] Feature Parity in PySpark for ml.recommendation
Adds a Python API for `ALS` under `ml.recommendation` in PySpark. Also adds seed as a settable parameter in the Scala implementation of ALS.

Author: Burak Yavuz <[email protected]>

Closes apache#6015 from brkyvz/ml-rec and squashes the following commits:

be6e931 [Burak Yavuz] addressed comments
eaed879 [Burak Yavuz] readd numFeatures
0bd66b1 [Burak Yavuz] fixed seed
7f6d964 [Burak Yavuz] merged master
52e2bda [Burak Yavuz] added ALS
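The headline addition in one line; a minimal sketch, assuming a DataFrame `df` with user, item, and rating columns:

    from pyspark.ml.recommendation import ALS

    # seed is now a settable parameter on both the Python and Scala sides
    model = ALS(rank=10, maxIter=5, seed=0).fit(df)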
1 parent 261167f commit 34be542

File tree

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
python/pyspark/ml/param/_shared_params_code_gen.py
python/pyspark/ml/param/shared.py
python/pyspark/ml/recommendation.py

4 files changed: +318 -4 lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 8 additions & 4 deletions
@@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
  * Common params for ALS.
  */
 private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
-  with HasPredictionCol with HasCheckpointInterval {
+  with HasPredictionCol with HasCheckpointInterval with HasSeed {
 
   /**
    * Param for rank of the matrix factorization (>= 1).
@@ -147,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
 
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
-    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
+    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L)
 
   /**
    * Validates and transforms the input schema.
@@ -278,6 +278,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   /** @group setParam */
   def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
 
+  /** @group setParam */
+  def setSeed(value: Long): this.type = set(seed, value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    * @group setParam
@@ -290,15 +293,16 @@ class ALS extends Estimator[ALSModel] with ALSParams {
 
   override def fit(dataset: DataFrame): ALSModel = {
     val ratings = dataset
-      .select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType))
+      .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
+        col($(ratingCol)).cast(FloatType))
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
     val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
       numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
-      checkpointInterval = $(checkpointInterval))
+      checkpointInterval = $(checkpointInterval), seed = $(seed))
     copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
   }
 
python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,8 @@ def get$Name(self):
     ("inputCol", "input column name", None),
     ("inputCols", "input column names", None),
     ("outputCol", "output column name", None),
+    ("numFeatures", "number of features", None),
+    ("checkpointInterval", "checkpoint interval (>= 1)", None),
     ("seed", "random seed", None),
     ("tol", "the convergence tolerance for iterative algorithms", None),
     ("stepSize", "Step size to be used for each iteration of optimization.", None)]

python/pyspark/ml/param/shared.py

Lines changed: 29 additions & 0 deletions
@@ -310,6 +310,35 @@ def getNumFeatures(self):
         return self.getOrDefault(self.numFeatures)
 
 
+class HasCheckpointInterval(Params):
+    """
+    Mixin for param checkpointInterval: checkpoint interval (>= 1).
+    """
+
+    # a placeholder to make it appear in the generated doc
+    checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1)")
+
+    def __init__(self):
+        super(HasCheckpointInterval, self).__init__()
+        #: param for checkpoint interval (>= 1)
+        self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)")
+        if None is not None:
+            self._setDefault(checkpointInterval=None)
+
+    def setCheckpointInterval(self, value):
+        """
+        Sets the value of :py:attr:`checkpointInterval`.
+        """
+        self.paramMap[self.checkpointInterval] = value
+        return self
+
+    def getCheckpointInterval(self):
+        """
+        Gets the value of checkpointInterval or its default value.
+        """
+        return self.getOrDefault(self.checkpointInterval)
+
+
 class HasSeed(Params):
     """
     Mixin for param seed: random seed.
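Any estimator that mixes in `HasCheckpointInterval` inherits the setter/getter pair. A quick sketch in a PySpark shell, using the new `ALS` class:

    from pyspark.ml.recommendation import ALS

    als = ALS()
    als.setCheckpointInterval(5)        # setter inherited from the mixin
    print(als.getCheckpointInterval())  # prints 5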
python/pyspark/ml/recommendation.py

Lines changed: 279 additions & 0 deletions

@@ -0,0 +1,279 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.util import keyword_only
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.param.shared import *
+from pyspark.mllib.common import inherit_doc
+
+
+__all__ = ['ALS', 'ALSModel']
+
+
+@inherit_doc
+class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed):
+    """
+    Alternating Least Squares (ALS) matrix factorization.
+
+    ALS attempts to estimate the ratings matrix `R` as the product of
+    two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically
+    these approximations are called 'factor' matrices. The general
+    approach is iterative. During each iteration, one of the factor
+    matrices is held constant, while the other is solved for using least
+    squares. The newly-solved factor matrix is then held constant while
+    solving for the other factor matrix.
+
+    This is a blocked implementation of the ALS factorization algorithm
+    that groups the two sets of factors (referred to as "users" and
+    "products") into blocks and reduces communication by only sending
+    one copy of each user vector to each product block on each
+    iteration, and only for the product blocks that need that user's
+    feature vector. This is achieved by pre-computing some information
+    about the ratings matrix to determine the "out-links" of each user
+    (which blocks of products it will contribute to) and "in-link"
+    information for each product (which of the feature vectors it
+    receives from each user block it will depend on). This allows us to
+    send only an array of feature vectors between each user block and
+    product block, and have the product block find the users' ratings
+    and update the products based on these messages.
+
+    For implicit preference data, the algorithm used is based on
+    "Collaborative Filtering for Implicit Feedback Datasets", available
+    at `http://dx.doi.org/10.1109/ICDM.2008.22`, adapted for the blocked
+    approach used here.
+
+    Essentially instead of finding the low-rank approximations to the
+    rating matrix `R`, this finds the approximations for a preference
+    matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0.
+    The ratings then act as 'confidence' values related to strength of
+    indicated user preferences rather than explicit ratings given to
+    items.
+
+    >>> als = ALS(rank=10, maxIter=5)
+    >>> model = als.fit(df)
+    >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
+    >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
+    >>> predictions[0]
+    Row(user=0, item=2, prediction=0.39...)
+    >>> predictions[1]
+    Row(user=1, item=0, prediction=3.19...)
+    >>> predictions[2]
+    Row(user=2, item=0, prediction=-1.15...)
+    """
+    _java_class = "org.apache.spark.ml.recommendation.ALS"
+    # a placeholder to make it appear in the generated doc
+    rank = Param(Params._dummy(), "rank", "rank of the factorization")
+    numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks")
+    numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks")
+    implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference")
+    alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference")
+    userCol = Param(Params._dummy(), "userCol", "column name for user ids")
+    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids")
+    ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings")
+    nonnegative = Param(Params._dummy(), "nonnegative",
+                        "whether to use nonnegative constraint for least squares")
+
+    @keyword_only
+    def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
+                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+                 ratingCol="rating", nonnegative=False, checkpointInterval=10):
+        """
+        __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
+                 implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0,
+                 ratingCol="rating", nonnegative=false, checkpointInterval=10)
+        """
+        super(ALS, self).__init__()
+        self.rank = Param(self, "rank", "rank of the factorization")
+        self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks")
+        self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks")
+        self.implicitPrefs = Param(self, "implicitPrefs", "whether to use implicit preference")
+        self.alpha = Param(self, "alpha", "alpha for implicit preference")
+        self.userCol = Param(self, "userCol", "column name for user ids")
+        self.itemCol = Param(self, "itemCol", "column name for item ids")
+        self.ratingCol = Param(self, "ratingCol", "column name for ratings")
+        self.nonnegative = Param(self, "nonnegative",
+                                 "whether to use nonnegative constraint for least squares")
+        self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
+                         implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+                         ratingCol="rating", nonnegative=False, checkpointInterval=10)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
+                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+                  ratingCol="rating", nonnegative=False, checkpointInterval=10):
+        """
+        setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
+                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+                  ratingCol="rating", nonnegative=False, checkpointInterval=10)
+        Sets params for ALS.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return ALSModel(java_model)
+
+    def setRank(self, value):
+        """
+        Sets the value of :py:attr:`rank`.
+        """
+        self.paramMap[self.rank] = value
+        return self
+
+    def getRank(self):
+        """
+        Gets the value of rank or its default value.
+        """
+        return self.getOrDefault(self.rank)
+
+    def setNumUserBlocks(self, value):
+        """
+        Sets the value of :py:attr:`numUserBlocks`.
+        """
+        self.paramMap[self.numUserBlocks] = value
+        return self
+
+    def getNumUserBlocks(self):
+        """
+        Gets the value of numUserBlocks or its default value.
+        """
+        return self.getOrDefault(self.numUserBlocks)
+
+    def setNumItemBlocks(self, value):
+        """
+        Sets the value of :py:attr:`numItemBlocks`.
+        """
+        self.paramMap[self.numItemBlocks] = value
+        return self
+
+    def getNumItemBlocks(self):
+        """
+        Gets the value of numItemBlocks or its default value.
+        """
+        return self.getOrDefault(self.numItemBlocks)
+
+    def setNumBlocks(self, value):
+        """
+        Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.
+        """
+        self.paramMap[self.numUserBlocks] = value
+        self.paramMap[self.numItemBlocks] = value
+
+    def setImplicitPrefs(self, value):
+        """
+        Sets the value of :py:attr:`implicitPrefs`.
+        """
+        self.paramMap[self.implicitPrefs] = value
+        return self
+
+    def getImplicitPrefs(self):
+        """
+        Gets the value of implicitPrefs or its default value.
+        """
+        return self.getOrDefault(self.implicitPrefs)
+
+    def setAlpha(self, value):
+        """
+        Sets the value of :py:attr:`alpha`.
+        """
+        self.paramMap[self.alpha] = value
+        return self
+
+    def getAlpha(self):
+        """
+        Gets the value of alpha or its default value.
+        """
+        return self.getOrDefault(self.alpha)
+
+    def setUserCol(self, value):
+        """
+        Sets the value of :py:attr:`userCol`.
+        """
+        self.paramMap[self.userCol] = value
+        return self
+
+    def getUserCol(self):
+        """
+        Gets the value of userCol or its default value.
+        """
+        return self.getOrDefault(self.userCol)
+
+    def setItemCol(self, value):
+        """
+        Sets the value of :py:attr:`itemCol`.
+        """
+        self.paramMap[self.itemCol] = value
+        return self
+
+    def getItemCol(self):
+        """
+        Gets the value of itemCol or its default value.
+        """
+        return self.getOrDefault(self.itemCol)
+
+    def setRatingCol(self, value):
+        """
+        Sets the value of :py:attr:`ratingCol`.
+        """
+        self.paramMap[self.ratingCol] = value
+        return self
+
+    def getRatingCol(self):
+        """
+        Gets the value of ratingCol or its default value.
+        """
+        return self.getOrDefault(self.ratingCol)
+
+    def setNonnegative(self, value):
+        """
+        Sets the value of :py:attr:`nonnegative`.
+        """
+        self.paramMap[self.nonnegative] = value
+        return self
+
+    def getNonnegative(self):
+        """
+        Gets the value of nonnegative or its default value.
+        """
+        return self.getOrDefault(self.nonnegative)
+
+
+class ALSModel(JavaModel):
+    """
+    Model fitted by ALS.
+    """
+
+
+if __name__ == "__main__":
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext("local[2]", "ml.recommendation tests")
+    sqlContext = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlContext'] = sqlContext
+    globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
+                                              (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
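To exercise the implicit-preference path described in the class docstring, a self-contained sketch (the app name and interaction data are illustrative):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.ml.recommendation import ALS

    sc = SparkContext("local[2]", "als-implicit-example")
    sqlContext = SQLContext(sc)
    # Interaction counts stand in for implicit feedback: larger values
    # mean higher confidence, not higher explicit ratings.
    clicks = sqlContext.createDataFrame(
        [(0, 0, 3.0), (0, 1, 1.0), (1, 1, 4.0), (2, 0, 2.0), (2, 2, 6.0)],
        ["user", "item", "rating"])
    als = ALS(implicitPrefs=True, alpha=1.0, rank=5, maxIter=5, seed=42)
    model = als.fit(clicks)
    model.transform(
        sqlContext.createDataFrame([(1, 0), (2, 1)], ["user", "item"])).show()
    sc.stop()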

0 commit comments
