#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc


__all__ = ['ALS', 'ALSModel']


@inherit_doc
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed):
    """
    Alternating Least Squares (ALS) matrix factorization.

    ALS attempts to estimate the ratings matrix `R` as the product of
    two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically
    these approximations are called 'factor' matrices. The general
    approach is iterative. During each iteration, one of the factor
    matrices is held constant, while the other is solved for using least
    squares. The newly-solved factor matrix is then held constant while
    solving for the other factor matrix.
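
    For intuition only, here is a minimal single-machine sketch of one
    such alternating sweep using dense, unregularized least squares
    (NumPy and all names below are illustrative and are not part of
    this implementation):

    >>> import numpy as np
    >>> R = np.array([[4.0, 2.0, 1.0], [2.0, 3.0, 4.0], [1.0, 1.0, 5.0]])
    >>> X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # initial user factors, rank 2
    >>> Y = np.linalg.lstsq(X, R)[0].T    # hold X fixed, solve for item factors Y
    >>> X = np.linalg.lstsq(Y, R.T)[0].T  # hold Y fixed, solve for user factors X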

    This is a blocked implementation of the ALS factorization algorithm
    that groups the two sets of factors (referred to as "users" and
    "products") into blocks and reduces communication by only sending
    one copy of each user vector to each product block on each
    iteration, and only for the product blocks that need that user's
    feature vector. This is achieved by pre-computing some information
    about the ratings matrix to determine the "out-links" of each user
    (which blocks of products it will contribute to) and "in-link"
    information for each product (which of the feature vectors it
    receives from each user block it will depend on). This allows us to
    send only an array of feature vectors between each user block and
    product block, and have the product block find the users' ratings
    and update the products based on these messages.
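
    The block counts can be tuned via :py:attr:`numUserBlocks` and
    :py:attr:`numItemBlocks`, or both at once with ``setNumBlocks``
    (``4`` below is an arbitrary illustrative value):

    >>> alsBlocked = ALS().setNumBlocks(4)
    >>> (alsBlocked.getNumUserBlocks(), alsBlocked.getNumItemBlocks())
    (4, 4)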

    For implicit preference data, the algorithm used is based on
    "Collaborative Filtering for Implicit Feedback Datasets", available
    at `http://dx.doi.org/10.1109/ICDM.2008.22`, adapted for the blocked
    approach used here.

    Essentially, instead of finding the low-rank approximations to the
    rating matrix `R`, this finds the approximations for a preference
    matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0.
    The ratings then act as 'confidence' values related to the strength
    of indicated user preferences rather than explicit ratings given to
    items.
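
    For example, implicit-feedback mode might be configured like this
    (``alpha=40.0`` is an arbitrary illustrative value, not a
    recommended setting):

    >>> alsImplicit = ALS(implicitPrefs=True, alpha=40.0)
    >>> alsImplicit.getImplicitPrefs()
    True
    >>> alsImplicit.getAlpha()
    40.0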

    >>> als = ALS(rank=10, maxIter=5)
    >>> model = als.fit(df)
    >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
    >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
    >>> predictions[0]
    Row(user=0, item=2, prediction=0.39...)
    >>> predictions[1]
    Row(user=1, item=0, prediction=3.19...)
    >>> predictions[2]
    Row(user=2, item=0, prediction=-1.15...)
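
    The param setters return ``self``, so configuration calls can be
    chained:

    >>> als.setRank(20).getRank()
    20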
    """
    _java_class = "org.apache.spark.ml.recommendation.ALS"
    # placeholders to make the params show up in the generated doc
    rank = Param(Params._dummy(), "rank", "rank of the factorization")
    numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks")
    numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks")
    implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference")
    alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference")
    userCol = Param(Params._dummy(), "userCol", "column name for user ids")
    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids")
    ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings")
    nonnegative = Param(Params._dummy(), "nonnegative",
                        "whether to use nonnegative constraint for least squares")

    @keyword_only
    def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
                 ratingCol="rating", nonnegative=False, checkpointInterval=10):
        """
        __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
                 ratingCol="rating", nonnegative=False, checkpointInterval=10)
        """
        super(ALS, self).__init__()
        self.rank = Param(self, "rank", "rank of the factorization")
        self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks")
        self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks")
        self.implicitPrefs = Param(self, "implicitPrefs", "whether to use implicit preference")
        self.alpha = Param(self, "alpha", "alpha for implicit preference")
        self.userCol = Param(self, "userCol", "column name for user ids")
        self.itemCol = Param(self, "itemCol", "column name for item ids")
        self.ratingCol = Param(self, "ratingCol", "column name for ratings")
        self.nonnegative = Param(self, "nonnegative",
                                 "whether to use nonnegative constraint for least squares")
        self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                         implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
                         ratingCol="rating", nonnegative=False, checkpointInterval=10)
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
                  ratingCol="rating", nonnegative=False, checkpointInterval=10):
        """
        setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
                  ratingCol="rating", nonnegative=False, checkpointInterval=10)
        Sets params for ALS.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return ALSModel(java_model)

    def setRank(self, value):
        """
        Sets the value of :py:attr:`rank`.
        """
        self.paramMap[self.rank] = value
        return self

    def getRank(self):
        """
        Gets the value of rank or its default value.
        """
        return self.getOrDefault(self.rank)

    def setNumUserBlocks(self, value):
        """
        Sets the value of :py:attr:`numUserBlocks`.
        """
        self.paramMap[self.numUserBlocks] = value
        return self

    def getNumUserBlocks(self):
        """
        Gets the value of numUserBlocks or its default value.
        """
        return self.getOrDefault(self.numUserBlocks)

    def setNumItemBlocks(self, value):
        """
        Sets the value of :py:attr:`numItemBlocks`.
        """
        self.paramMap[self.numItemBlocks] = value
        return self

    def getNumItemBlocks(self):
        """
        Gets the value of numItemBlocks or its default value.
        """
        return self.getOrDefault(self.numItemBlocks)

    def setNumBlocks(self, value):
        """
        Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.
        """
        self.paramMap[self.numUserBlocks] = value
        self.paramMap[self.numItemBlocks] = value
        return self

    def setImplicitPrefs(self, value):
        """
        Sets the value of :py:attr:`implicitPrefs`.
        """
        self.paramMap[self.implicitPrefs] = value
        return self

    def getImplicitPrefs(self):
        """
        Gets the value of implicitPrefs or its default value.
        """
        return self.getOrDefault(self.implicitPrefs)

    def setAlpha(self, value):
        """
        Sets the value of :py:attr:`alpha`.
        """
        self.paramMap[self.alpha] = value
        return self

    def getAlpha(self):
        """
        Gets the value of alpha or its default value.
        """
        return self.getOrDefault(self.alpha)

    def setUserCol(self, value):
        """
        Sets the value of :py:attr:`userCol`.
        """
        self.paramMap[self.userCol] = value
        return self

    def getUserCol(self):
        """
        Gets the value of userCol or its default value.
        """
        return self.getOrDefault(self.userCol)

    def setItemCol(self, value):
        """
        Sets the value of :py:attr:`itemCol`.
        """
        self.paramMap[self.itemCol] = value
        return self

    def getItemCol(self):
        """
        Gets the value of itemCol or its default value.
        """
        return self.getOrDefault(self.itemCol)

    def setRatingCol(self, value):
        """
        Sets the value of :py:attr:`ratingCol`.
        """
        self.paramMap[self.ratingCol] = value
        return self

    def getRatingCol(self):
        """
        Gets the value of ratingCol or its default value.
        """
        return self.getOrDefault(self.ratingCol)

    def setNonnegative(self, value):
        """
        Sets the value of :py:attr:`nonnegative`.
        """
        self.paramMap[self.nonnegative] = value
        return self

    def getNonnegative(self):
        """
        Gets the value of nonnegative or its default value.
        """
        return self.getOrDefault(self.nonnegative)


class ALSModel(JavaModel):
    """
    Model fitted by ALS.
    """


if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    sc = SparkContext("local[2]", "ml.recommendation tests")
    sqlContext = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlContext'] = sqlContext
    globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
                                              (2, 1, 1.0), (2, 2, 5.0)],
                                             ["user", "item", "rating"])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)