diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd01dd4dc579..718ed812dd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.function._
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 
@@ -199,7 +200,6 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    encoderFor[T].assertUnresolved()
     new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
@@ -519,7 +519,7 @@ class Dataset[T] private[sql](
    * Returns the first element in this [[Dataset]].
    * @since 1.6.0
    */
-  def first(): T = rdd.first()
+  def first(): T = take(1).headOption.getOrElse(throw new NoSuchElementException("empty Dataset"))
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
@@ -530,7 +530,14 @@ class Dataset[T] private[sql](
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
-  def collect(): Array[T] = rdd.collect()
+  def collect(): Array[T] = {
+    // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
+    // to convert the rows into objects of type T.
+    val tEnc = resolvedTEncoder
+    val input = queryExecution.analyzed.output
+    val bound = tEnc.bind(input)
+    queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+  }
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
@@ -541,7 +548,7 @@ class Dataset[T] private[sql](
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
-  def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
+  def collectAsList(): java.util.List[T] = collect().toSeq.asJava
 
   /**
    * Returns the first `num` elements of this [[Dataset]] as an array.
@@ -551,7 +558,7 @@ class Dataset[T] private[sql](
    *
    * @since 1.6.0
    */
-  def take(num: Int): Array[T] = rdd.take(num)
+  def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
 
   /**
    * Returns the first `num` elements of this [[Dataset]] as an array.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a3922340ccc9..ea29428c5508 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+
 import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
@@ -24,6 +26,20 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 case class ClassData(a: String, b: Int)
 
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+  override def readExternal(in: ObjectInput): Unit = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    throw new UnsupportedOperationException
+  }
+}
+
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -41,6 +57,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
+  test("collect, first, and take should use encoders for serialization") {
+    val item = NonSerializableCaseClass("abcd")
+    val ds = Seq(item).toDS()
+    assert(ds.collect().head == item)
+    assert(ds.collectAsList().get(0) == item)
+    assert(ds.first() == item)
+    assert(ds.take(1).head == item)
+    assert(ds.takeAsList(1).get(0) == item)
+  }
+
   test("as tuple") {
     val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
     checkAnswer(
@@ -75,6 +101,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   ignore("Dataset should set the resolved encoders internally for maps") {
     // TODO: Enable this once we fix SPARK-11793.
+    // We inject a group by here to make sure this test case is future proof
+    // when we implement better pipelining and local execution mode.
     val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
       .map(c => ClassData(c.a, c.b + 1))
       .groupBy(p => p).count()
@@ -219,7 +247,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy function, fatMap") {
+  test("groupBy function, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
     val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }