Skip to content

Commit 52e2bda

Browse files
committed
added ALS
1 parent c796be7 commit 52e2bda

File tree

4 files changed

+396
-18
lines changed

4 files changed

+396
-18
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
4949
* Common params for ALS.
5050
*/
5151
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
52-
with HasPredictionCol with HasCheckpointInterval {
52+
with HasPredictionCol with HasCheckpointInterval with HasSeed {
5353

5454
/**
5555
* Param for rank of the matrix factorization (>= 1).
@@ -147,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
147147

148148
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
149149
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
150-
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
150+
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 42L)
151151

152152
/**
153153
* Validates and transforms the input schema.
@@ -290,15 +290,16 @@ class ALS extends Estimator[ALSModel] with ALSParams {
290290

291291
override def fit(dataset: DataFrame): ALSModel = {
292292
val ratings = dataset
293-
.select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType))
293+
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
294+
col($(ratingCol)).cast(FloatType))
294295
.map { row =>
295296
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
296297
}
297298
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
298299
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
299300
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
300301
alpha = $(alpha), nonnegative = $(nonnegative),
301-
checkpointInterval = $(checkpointInterval))
302+
checkpointInterval = $(checkpointInterval), seed = $(seed))
302303
copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
303304
}
304305

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def get$Name(self):
9797
("inputCol", "input column name", None),
9898
("inputCols", "input column names", None),
9999
("outputCol", "output column name", None),
100-
("numFeatures", "number of features", None)]
100+
("checkpointInterval", "checkpoint interval (>= 1)", None)]
101101
code = []
102102
for name, doc, defaultValueStr in shared:
103103
code.append(_gen_param_code(name, doc, defaultValueStr))

python/pyspark/ml/param/shared.py

Lines changed: 119 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -281,30 +281,136 @@ def getOutputCol(self):
281281
return self.getOrDefault(self.outputCol)
282282

283283

284-
class HasNumFeatures(Params):
284+
285+
class HasNumItemBlocks(Params):
286+
"""
287+
Mixin for param numItemBlocks: number of item blocks.
288+
"""
289+
290+
# a placeholder to make it appear in the generated doc
291+
292+
293+
def __init__(self):
294+
super(HasNumItemBlocks, self).__init__()
295+
#: param for number of item blocks
296+
297+
if None is not None:
298+
self._setDefault(numItemBlocks=None)
299+
300+
301+
302+
class HasImplicitPrefs(Params):
303+
"""
304+
Mixin for param implicitPrefs: whether to use implicit preference.
305+
"""
306+
307+
# a placeholder to make it appear in the generated doc
308+
309+
310+
def __init__(self):
311+
super(HasImplicitPrefs, self).__init__()
312+
#: param for whether to use implicit preference
313+
314+
if None is not None:
315+
self._setDefault(implicitPrefs=None)
316+
317+
318+
319+
320+
class HasAlpha(Params):
321+
"""
322+
Mixin for param alpha: alpha for implicit preference.
323+
"""
324+
325+
# a placeholder to make it appear in the generated doc
326+
327+
328+
def __init__(self):
329+
super(HasAlpha, self).__init__()
330+
#: param for alpha for implicit preference
331+
332+
if None is not None:
333+
self._setDefault(alpha=None)
334+
335+
336+
337+
338+
class HasUserCol(Params):
339+
"""
340+
Mixin for param userCol: column name for user ids.
341+
"""
342+
343+
# a placeholder to make it appear in the generated doc
344+
345+
346+
def __init__(self):
347+
super(HasUserCol, self).__init__()
348+
#: param for column name for user ids
349+
350+
if None is not None:
351+
self._setDefault(userCol=None)
352+
353+
354+
355+
356+
class HasItemCol(Params):
357+
"""
358+
Mixin for param itemCol: column name for item ids.
359+
"""
360+
361+
# a placeholder to make it appear in the generated doc
362+
363+
364+
def __init__(self):
365+
super(HasItemCol, self).__init__()
366+
#: param for column name for item ids
367+
368+
if None is not None:
369+
self._setDefault(itemCol=None)
370+
371+
372+
373+
374+
class HasRatingCol(Params):
375+
"""
376+
Mixin for param ratingCol: column name for ratings.
377+
"""
378+
379+
# a placeholder to make it appear in the generated doc
380+
381+
382+
def __init__(self):
383+
super(HasRatingCol, self).__init__()
384+
#: param for column name for ratings
385+
386+
if None is not None:
387+
self._setDefault(ratingCol=None)
388+
389+
390+
class HasCheckpointInterval(Params):
285391
"""
286-
Mixin for param numFeatures: number of features.
392+
Mixin for param checkpointInterval: checkpoint interval (>= 1).
287393
"""
288394

289395
# a placeholder to make it appear in the generated doc
290-
numFeatures = Param(Params._dummy(), "numFeatures", "number of features")
396+
checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1)")
291397

292398
def __init__(self):
293-
super(HasNumFeatures, self).__init__()
294-
#: param for number of features
295-
self.numFeatures = Param(self, "numFeatures", "number of features")
399+
super(HasCheckpointInterval, self).__init__()
400+
#: param for checkpoint interval (>= 1)
401+
self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)")
296402
if None is not None:
297-
self._setDefault(numFeatures=None)
403+
self._setDefault(checkpointInterval=None)
298404

299-
def setNumFeatures(self, value):
405+
def setCheckpointInterval(self, value):
300406
"""
301-
Sets the value of :py:attr:`numFeatures`.
407+
Sets the value of :py:attr:`checkpointInterval`.
302408
"""
303-
self.paramMap[self.numFeatures] = value
409+
self.paramMap[self.checkpointInterval] = value
304410
return self
305411

306-
def getNumFeatures(self):
412+
def getCheckpointInterval(self):
307413
"""
308-
Gets the value of numFeatures or its default value.
414+
Gets the value of checkpointInterval or its default value.
309415
"""
310-
return self.getOrDefault(self.numFeatures)
416+
return self.getOrDefault(self.checkpointInterval)

0 commit comments

Comments
 (0)