
Commit 5ad78f6

[SQL] Various DataFrame DSL update.
1. Added foreach, foreachPartition, flatMap to DataFrame.
2. Added col() in dsl.
3. Support renaming columns in toDataFrame.
4. Support type inference on arrays (in addition to Seq).
5. Updated mllib to use the new DSL.

Author: Reynold Xin <[email protected]>

Closes #4260 from rxin/sql-dsl-update and squashes the following commits:

73466c1 [Reynold Xin] Fixed LogisticRegression. Also added better error message for resolve.
fab3ccc [Reynold Xin] Bug fix.
d31fcd2 [Reynold Xin] Style fix.
62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update.
1 parent a63be1a commit 5ad78f6
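
As a quick orientation, here is a hedged sketch of the user-facing pieces from items 2 and 3 of the message. It assumes a Spark build at this commit, an RDD[(Int, String)] named rdd, and that the implicit RDD-to-DataFrame conversion (e.g. the sqlContext.createDataFrame import used in the ALS change below) is in scope; all identifiers are illustrative.

import org.apache.spark.sql.api.scala.dsl._     // brings col(), lit() and $"..." into scope

val people = rdd.toDataFrame("id", "name")      // rename the tuple columns _1/_2 on conversion
val projected = people.select(col("id"), col("name"))  // col() builds a Column from a name
projected.foreach(row => println(row))          // RDD-style action now available on DataFrame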

File tree

12 files changed: +114 lines, -45 lines


mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 1 addition & 2 deletions
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types._

@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
     dataset.select($"*", callUDF(
-      this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
+      this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
   }
 }
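
The change from Column(map(inputCol)) to dataset(map(inputCol)) asks the DataFrame itself to resolve the column, so typos fail immediately with the richer error message added to resolve() in the DataFrame.scala hunk below, and the resulting Column is tied to that dataset rather than being a free-floating name. A hedged sketch of the same pattern outside MLlib (df, the column name, and the UDF are illustrative; callUDF with an explicit return type mirrors the call above):

import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.types.IntegerType

// df("text") is resolved against df, then fed to a UDF that appends a derived column.
val lengths = df.select($"*",
  callUDF((s: String) => s.length, IntegerType, df("text")).as("text_length"))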

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 5 additions & 7 deletions
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel

@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val score: Vector => Double = (v) => {
+    val scoreFunction: Vector => Double = (v) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
     }
     val t = map(threshold)
-    val predict: Double => Double = (score) => {
-      if (score > t) 1.0 else 0.0
-    }
-    dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
+    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+    dataset
+      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
   }
 }
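
The rewritten transform shows the DSL pattern this commit pushes MLlib toward: wrap a plain Scala function with callUDF, append its result with select($"*", ...), and chain another select for the next derived column. A hedged, simplified sketch of that pattern (df, the column names, and the 0.5 threshold are illustrative; assumes the api.scala.dsl import as in the diff):

val sigmoid: Double => Double = m => 1.0 / (1.0 + math.exp(-m))
val toLabel: Double => Double = s => if (s > 0.5) 1.0 else 0.0

val predictions = df
  .select($"*", callUDF(sigmoid, col("margin")).as("score"))      // append a score column
  .select($"*", callUDF(toLabel, col("score")).as("prediction"))  // then a prediction column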

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**

@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
     val scale: (Vector) => Vector = (v) => {
       scaler.transform(v)
     }
-    dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
+    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 11 additions & 24 deletions
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    import dataset.sqlContext._
-    import org.apache.spark.ml.recommendation.ALSModel.Factor
+    import dataset.sqlContext.createDataFrame
     val map = this.paramMap ++ paramMap
-    // TODO: Add DSL to simplify the code here.
-    val instanceTable = s"instance_$uid"
-    val userTable = s"user_$uid"
-    val itemTable = s"item_$uid"
-    val instances = dataset.as(instanceTable)
-    val users = userFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(userTable)
-    val items = itemFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(itemTable)
+    val users = userFactors.toDataFrame("id", "features")
+    val items = itemFactors.toDataFrame("id", "features")
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)

@@ -133,24 +123,21 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
-      .as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
-    instances
-      .join(users, Column(map(userCol)) === $"$userTable.id", "left")
-      .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+    val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    dataset
+      .join(users, dataset(map(userCol)) === users("id"), "left")
+      .join(items, dataset(map(itemCol)) === items("id"), "left")
       .select(outputColumns: _*)
+      // TODO: Just use a dataset("*")
+      // .select(dataset("*"), prediction)
   }
 
   override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }
 
-private object ALSModel {
-  /** Case class to convert factors to [[DataFrame]]s */
-  private case class Factor(id: Int, features: Seq[Float])
-}
 
 /**
  * Alternating Least Squares (ALS) matrix factorization.

@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
     val ratings = dataset
-      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
       .map { row =>
        new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
      }
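
The new body leans on two pieces added elsewhere in this commit: toDataFrame(names) to turn the factor RDDs into DataFrames with meaningful column names, and df(colName) to disambiguate columns after the joins, which removes the old Factor case class and the temporary table aliases. A hedged sketch of that join pattern in isolation (userFactors: RDD[(Int, Seq[Float])], dataset, and the column names are illustrative):

val users = userFactors.toDataFrame("id", "features")
val joined = dataset
  .join(users, dataset("userId") === users("id"), "left")
  .select(dataset("userId"), users("features").as("userFeatures"))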

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
 /**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 4 additions & 1 deletion
@@ -57,6 +57,7 @@ trait ScalaReflection {
     case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
     case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
     case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+    case (s: Array[_], arrayType: ArrayType) => s.toSeq
     case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
       convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
     }

@@ -140,7 +141,9 @@ trait ScalaReflection {
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
     case t if t <:< typeOf[Array[_]] =>
-      sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+      val TypeRef(_, _, Seq(elementType)) = t
+      val Schema(dataType, nullable) = schemaFor(elementType)
+      Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
     case t if t <:< typeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
       val Schema(dataType, nullable) = schemaFor(elementType)
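
With this case in place, schemaFor infers an ArrayType for any Array element type instead of erroring out for everything except Array[Byte], and convertToCatalyst turns the Array into a Seq at runtime. A hedged illustration of the effect (assumes a SparkContext sc and SQLContext sqlContext, and the same createDataFrame conversion imported in the ALS change above; the case class is illustrative):

case class Scores(id: Int, values: Array[Double])

val df = sqlContext.createDataFrame(sc.parallelize(Seq(Scores(1, Array(0.1, 0.9)))))
println(df.schema)
// Expected (assumption, per the new schemaFor case):
//   id -> IntegerType (nullable = false)
//   values -> ArrayType(DoubleType, containsNull = false) (nullable = true)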

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 5 additions & 0 deletions
@@ -60,6 +60,7 @@ case class OptionalData(
 
 case class ComplexData(
     arrayField: Seq[Int],
+    arrayField1: Array[Int],
     arrayFieldContainsNull: Seq[java.lang.Integer],
     mapField: Map[Int, Long],
     mapFieldValueContainsNull: Map[Int, java.lang.Long],

@@ -131,6 +132,10 @@ class ScalaReflectionSuite extends FunSuite {
         "arrayField",
         ArrayType(IntegerType, containsNull = false),
         nullable = true),
+      StructField(
+        "arrayField1",
+        ArrayType(IntegerType, containsNull = false),
+        nullable = true),
       StructField(
         "arrayFieldContainsNull",
         ArrayType(IntegerType, containsNull = true),

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 8 additions & 4 deletions
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
 import org.apache.spark.sql.api.scala.dsl.lit
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
 import org.apache.spark.sql.types._
 
 
 object Column {
-  def unapply(col: Column): Option[Expression] = Some(col.expr)
-
+  /**
+   * Creates a [[Column]] based on the given column name.
+   * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
+   */
   def apply(colName: String): Column = new Column(colName)
+
+  /** For internal pattern matching. */
+  private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
 }
 
 

@@ -438,7 +442,7 @@ class Column(
    * @param ordinal
    * @return
    */
-  override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+  override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
 
   /**
    * An expression that gets a field by name in a [[StructField]].
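
With unapply now private[sql], Column(name) on the public side is purely a constructor, matching the new col() function mentioned in the commit message. A hedged summary of the interchangeable ways to name a column under this snapshot's DSL (df and the "age" column are illustrative):

import org.apache.spark.sql.Column
import org.apache.spark.sql.api.scala.dsl._

df.select(Column("age"))  // explicit constructor
df.select(col("age"))     // col() helper added by this commit
df.select($"age")         // interpolator form used throughout these diffs
df.select(df("age"))      // resolved against df itself; unresolvable names fail fast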

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 43 additions & 4 deletions
@@ -118,8 +118,8 @@ class DataFrame protected[sql](
 
   /** Resolves a column name into a Catalyst [[NamedExpression]]. */
   protected[sql] def resolve(colName: String): NamedExpression = {
-    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
-      throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException(
+      s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
   }
 
   /** Left here for compatibility reasons. */

@@ -131,6 +131,29 @@ class DataFrame protected[sql](
    */
   def toDataFrame: DataFrame = this
 
+  /**
+   * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
+   * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
+   * {{{
+   *   val rdd: RDD[(Int, String)] = ...
+   *   rdd.toDataFrame  // this implicit conversion creates a DataFrame with column name _1 and _2
+   *   rdd.toDataFrame("id", "name")  // this creates a DataFrame with column name "id" and "name"
+   * }}}
+   */
+  @scala.annotation.varargs
+  def toDataFrame(colName: String, colNames: String*): DataFrame = {
+    val newNames = colName +: colNames
+    require(schema.size == newNames.size,
+      "The number of columns doesn't match.\n" +
+        "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+        "New column names: " + newNames.mkString(", "))
+
+    val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+      apply(oldName).as(newName)
+    }
+    select(newCols :_*)
+  }
+
   /** Returns the schema of this [[DataFrame]]. */
   override def schema: StructType = queryExecution.analyzed.schema

@@ -227,7 +250,7 @@ class DataFrame protected[sql](
   }
 
   /**
-   * Selects a single column and return it as a [[Column]].
+   * Selects column based on the column name and return it as a [[Column]].
    */
   override def apply(colName: String): Column = colName match {
     case "*" =>

@@ -466,13 +489,29 @@ class DataFrame protected[sql](
     rdd.map(f)
   }
 
+  /**
+   * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+   * and then flattening the results.
+   */
+  override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
   /**
    * Returns a new RDD by applying a function to each partition of this DataFrame.
    */
   override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
     rdd.mapPartitions(f)
   }
 
+  /**
+   * Applies a function `f` to all rows.
+   */
+  override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+  /**
+   * Applies a function f to each partition of this [[DataFrame]].
+   */
+  override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+
   /**
    * Returns the first `n` rows in the [[DataFrame]].
   */

@@ -520,7 +559,7 @@ class DataFrame protected[sql](
   /////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
+   * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
   */
   override def rdd: RDD[Row] = {
     val schema = this.schema
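
A hedged usage sketch of the methods added in this file: flatMap, foreach, foreachPartition, and the renaming toDataFrame. It assumes a DataFrame df whose first column is a string and which has exactly two columns for the rename; everything else is illustrative:

val words = df.flatMap(row => row.getString(0).split(" "))   // one row -> zero or more records
df.foreach(row => println(row))                              // side effect per row
df.foreachPartition(rows => println(s"rows in partition: ${rows.size}"))
val renamed = df.toDataFrame("id", "name")                   // per the docstring above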

sql/core/src/main/scala/org/apache/spark/sql/api.scala

Lines changed: 6 additions & 0 deletions
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {
 
   def map[R: ClassTag](f: T => R): RDD[R]
 
+  def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]
+
   def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
 
+  def foreach(f: T => Unit): Unit
+
+  def foreachPartition(f: Iterator[T] => Unit): Unit
+
   def take(n: Int): Array[T]
 
   def collect(): Array[T]
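
These trait members are satisfied in DataFrame.scala above by delegating to the underlying RDD[Row]. A minimal hedged sketch of that delegation pattern (RowSetApi is a hypothetical class, not part of Spark; RDDApi itself is private[sql], so this is purely illustrative):

import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

class RowSetApi(underlying: RDD[Row]) {
  def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = underlying.flatMap(f)
  def foreach(f: Row => Unit): Unit = underlying.foreach(f)
  def foreachPartition(f: Iterator[Row] => Unit): Unit = underlying.foreachPartition(f)
}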

0 commit comments
