Skip to content

Commit bce8692

Browse files
committed
Add doc for parameters and project the output columns
1 parent 3f2d81a commit bce8692

File tree

1 file changed

+20
-7
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/recommendation

1 file changed

+20
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@ import org.apache.spark.ml.{Estimator, Model}
3030
import org.apache.spark.ml.param._
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.sql.{SchemaRDD, StructType}
33-
import org.apache.spark.sql.catalyst.analysis.Star
3433
import org.apache.spark.sql.catalyst.dsl._
3534
import org.apache.spark.sql.catalyst.expressions.Cast
3635
import org.apache.spark.sql.catalyst.plans.LeftOuter
@@ -44,30 +43,38 @@ import org.apache.spark.util.random.XORShiftRandom
4443
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
4544
with HasPredictionCol {
4645

46+
/** Param for rank of the matrix factorization. */
4747
val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
4848
def getRank: Int = get(rank)
4949

50+
/** Param for number of user blocks. */
5051
val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
5152
def getNumUserBlocks: Int = get(numUserBlocks)
5253

54+
/** Param for number of product blocks. */
5355
val numProductBlocks =
5456
new IntParam(this, "numProductBlocks", "number of product blocks", Some(10))
5557
def getNumProductBlocks: Int = get(numProductBlocks)
5658

59+
/** Param to decide whether to use implicit preference. */
5760
val implicitPrefs =
5861
new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
5962
def getImplicitPrefs: Boolean = get(implicitPrefs)
6063

64+
/** Param for the alpha parameter in the implicit preference formulation. */
6165
val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
6266
def getAlpha: Double = get(alpha)
6367

68+
/** Param for the column name for user ids. */
6469
val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
6570
def getUserCol: String = get(userCol)
6671

72+
/** Param for the column name for product ids. */
6773
val productCol =
6874
new Param[String](this, "productCol", "column name for product ids", Some("product"))
6975
def getProductCol: String = get(productCol)
7076

77+
/** Param for the column name for ratings. */
7178
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
7279
def getRatingCol: String = get(ratingCol)
7380

@@ -108,8 +115,11 @@ class ALSModel private[ml] (
108115
import dataset.sqlContext._
109116
import org.apache.spark.ml.recommendation.ALSModel.Factor
110117
val map = this.paramMap ++ paramMap
111-
val userTable = s"user-$uid"
112-
val prodTable = s"prod-$uid"
118+
// TODO: Add DSL to simplify the code here.
119+
val instanceTable = s"instance_$uid"
120+
val userTable = s"user_$uid"
121+
val prodTable = s"prod_$uid"
122+
val instances = dataset.as(Symbol(instanceTable))
113123
val users = userFactors.map { case (id, features) =>
114124
Factor(id, features)
115125
}.as(Symbol(userTable))
@@ -123,11 +133,14 @@ class ALSModel private[ml] (
123133
Float.NaN
124134
}
125135
}
126-
dataset.join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
136+
val inputColumns = dataset.schema.fieldNames
137+
val prediction =
138+
predict.call(s"$userTable.features".attr, s"$prodTable.features".attr) as map(predictionCol)
139+
val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
140+
instances
141+
.join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
127142
.join(prods, LeftOuter, Some(map(productCol).attr === s"$prodTable.id".attr))
128-
.select(Star(None),
129-
predict.call(s"$userTable.features".attr, s"$prodTable.features".attr)
130-
as map(predictionCol))
143+
.select(outputColumns: _*)
131144
}
132145

133146
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {

0 commit comments

Comments
 (0)