@@ -30,7 +30,6 @@ import org.apache.spark.ml.{Estimator, Model}
3030import org .apache .spark .ml .param ._
3131import org .apache .spark .rdd .RDD
3232import org .apache .spark .sql .{SchemaRDD , StructType }
33- import org .apache .spark .sql .catalyst .analysis .Star
3433import org .apache .spark .sql .catalyst .dsl ._
3534import org .apache .spark .sql .catalyst .expressions .Cast
3635import org .apache .spark .sql .catalyst .plans .LeftOuter
@@ -44,30 +43,38 @@ import org.apache.spark.util.random.XORShiftRandom
4443private [recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
4544 with HasPredictionCol {
4645
46+ /** Param for rank of the matrix factorization. */
4747 val rank = new IntParam (this , " rank" , " rank of the factorization" , Some (10 ))
4848 def getRank : Int = get(rank)
4949
50+ /** Param for number of user blocks. */
5051 val numUserBlocks = new IntParam (this , " numUserBlocks" , " number of user blocks" , Some (10 ))
5152 def getNumUserBlocks : Int = get(numUserBlocks)
5253
54+ /** Param for number of product blocks. */
5355 val numProductBlocks =
5456 new IntParam (this , " numProductBlocks" , " number of product blocks" , Some (10 ))
5557 def getNumProductBlocks : Int = get(numProductBlocks)
5658
59+ /** Param to decide whether to use implicit preference. */
5760 val implicitPrefs =
5861 new BooleanParam (this , " implicitPrefs" , " whether to use implicit preference" , Some (false ))
5962 def getImplicitPrefs : Boolean = get(implicitPrefs)
6063
64+ /** Param for the alpha parameter in the implicit preference formulation. */
6165 val alpha = new DoubleParam (this , " alpha" , " alpha for implicit preference" , Some (1.0 ))
6266 def getAlpha : Double = get(alpha)
6367
68+ /** Param for the column name for user ids. */
6469 val userCol = new Param [String ](this , " userCol" , " column name for user ids" , Some (" user" ))
6570 def getUserCol : String = get(userCol)
6671
72+ /** Param for the column name for product ids. */
6773 val productCol =
6874 new Param [String ](this , " productCol" , " column name for product ids" , Some (" product" ))
6975 def getProductCol : String = get(productCol)
7076
77+ /** Param for the column name for ratings. */
7178 val ratingCol = new Param [String ](this , " ratingCol" , " column name for ratings" , Some (" rating" ))
7279 def getRatingCol : String = get(ratingCol)
7380
@@ -108,8 +115,11 @@ class ALSModel private[ml] (
108115 import dataset .sqlContext ._
109116 import org .apache .spark .ml .recommendation .ALSModel .Factor
110117 val map = this .paramMap ++ paramMap
111- val userTable = s " user- $uid"
112- val prodTable = s " prod- $uid"
118+ // TODO: Add DSL to simplify the code here.
119+ val instanceTable = s " instance_ $uid"
120+ val userTable = s " user_ $uid"
121+ val prodTable = s " prod_ $uid"
122+ val instances = dataset.as(Symbol (instanceTable))
113123 val users = userFactors.map { case (id, features) =>
114124 Factor (id, features)
115125 }.as(Symbol (userTable))
@@ -123,11 +133,14 @@ class ALSModel private[ml] (
123133 Float .NaN
124134 }
125135 }
126- dataset.join(users, LeftOuter , Some (map(userCol).attr === s " $userTable.id " .attr))
136+ val inputColumns = dataset.schema.fieldNames
137+ val prediction =
138+ predict.call(s " $userTable.features " .attr, s " $prodTable.features " .attr) as map (predictionCol)
139+ val outputColumns = inputColumns.map(f => s " $instanceTable. $f" .attr as f ) :+ prediction
140+ instances
141+ .join(users, LeftOuter , Some (map(userCol).attr === s " $userTable.id " .attr))
127142 .join(prods, LeftOuter , Some (map(productCol).attr === s " $prodTable.id " .attr))
128- .select(Star (None ),
129- predict.call(s " $userTable.features " .attr, s " $prodTable.features " .attr)
130- as map (predictionCol))
143+ .select(outputColumns : _* )
131144 }
132145
133146 override private [ml] def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
0 commit comments