From 9dec76fbd37376970bbb6a3f894ddb9cc48a8f43 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 17 Nov 2015 17:59:27 -0800
Subject: [PATCH 1/2] [SPARK-11797][SQL] collect, first, and take should use encoders for serialization.

---
 .../scala/org/apache/spark/sql/Dataset.scala  | 14 ++++++++---
 .../org/apache/spark/sql/DatasetSuite.scala   | 25 ++++++++++++++++++-
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd01dd4dc579..33531d635dc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.function._
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
@@ -519,7 +520,7 @@ class Dataset[T] private[sql](
    * Returns the first element in this [[Dataset]].
    * @since 1.6.0
    */
-  def first(): T = rdd.first()
+  def first(): T = take(1).head
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
@@ -530,7 +531,12 @@
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
-  def collect(): Array[T] = rdd.collect()
+  def collect(): Array[T] = {
+    val tEnc = resolvedTEncoder
+    val input = queryExecution.analyzed.output
+    val bound = tEnc.bind(input)
+    queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+  }
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
@@ -541,7 +547,7 @@
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
-  def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
+  def collectAsList(): java.util.List[T] = collect().toSeq.asJava
 
   /**
    * Returns the first `num` elements of this [[Dataset]] as an array.
@@ -551,7 +557,7 @@
    *
    * @since 1.6.0
    */
-  def take(num: Int): Array[T] = rdd.take(num)
+  def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
 
   /**
    * Returns the first `num` elements of this [[Dataset]] as an array.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a3922340ccc9..2ec9fddad080 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+
 import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
@@ -24,6 +26,17 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 case class ClassData(a: String, b: Int)
 
+/** A class used to test serialization using encoders. */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+  override def readExternal(in: ObjectInput): Unit = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    throw new UnsupportedOperationException
+  }
+}
+
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -41,6 +54,16 @@
       1, 1, 1)
   }
 
+  test("collect, first, and take should use encoders for serialization") {
+    val item = NonSerializableCaseClass("abcd")
+    val ds = Seq(item).toDS()
+    assert(ds.collect().head == item)
+    assert(ds.collectAsList().get(0) == item)
+    assert(ds.first() == item)
+    assert(ds.take(1).head == item)
+    assert(ds.takeAsList(1).get(0) == item)
+  }
+
   test("as tuple") {
     val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
     checkAnswer(
@@ -219,7 +242,7 @@
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy function, fatMap") {
+  test("groupBy function, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
     val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }

From 9c3bd921a00c15a71b18468793b6dfc81f3d52c0 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 17 Nov 2015 19:09:12 -0800
Subject: [PATCH 2/2] Address comments.

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 ++-
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 33531d635dc7..718ed812dd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -200,7 +200,6 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    encoderFor[T].assertUnresolved()
     new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
@@ -532,6 +531,8 @@
    * @since 1.6.0
    */
   def collect(): Array[T] = {
+    // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
+    // to convert the rows into objects of type T.
     val tEnc = resolvedTEncoder
     val input = queryExecution.analyzed.output
     val bound = tEnc.bind(input)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 2ec9fddad080..ea29428c5508 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -26,7 +26,10 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 case class ClassData(a: String, b: Int)
 
-/** A class used to test serialization using encoders. */
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
 case class NonSerializableCaseClass(value: String) extends Externalizable {
   override def readExternal(in: ObjectInput): Unit = {
     throw new UnsupportedOperationException
@@ -98,6 +101,8 @@
 
   ignore("Dataset should set the resolved encoders internally for maps") {
     // TODO: Enable this once we fix SPARK-11793.
+    // We inject a group by here to make sure this test case is future proof
+    // when we implement better pipelining and local execution mode.
     val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
       .map(c => ClassData(c.a, c.b + 1))
       .groupBy(p => p).count()
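
For context, the end-to-end behavior these two patches enable can be sketched as a small standalone program. This is a minimal sketch, assuming the Spark 1.6-era SQLContext API and the NonSerializableCaseClass defined in DatasetSuite above; the object name, app name, and local master setting are illustrative and not part of the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object EncoderCollectSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("encoder-collect-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // NonSerializableCaseClass throws from both of its Externalizable hooks, so
    // any attempt to move it through Java serialization fails. With these patches,
    // collect()/first()/take() ship InternalRows and only run the bound encoder's
    // fromRow on the driver, so the calls below succeed anyway.
    val item = NonSerializableCaseClass("abcd")
    val ds = Seq(item).toDS()

    assert(ds.first() == item)
    assert(ds.take(1).head == item)
    assert(ds.collect().sameElements(Array(item)))

    sc.stop()
  }
}

Before this change, all three methods delegated to Dataset.rdd and therefore Java-serialized each element; routing them through the encoder keeps the data in the InternalRow format end to end and decodes it once, on the driver.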