diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index c8b017e25163..79c2255641c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql
 
-import scala.reflect.ClassTag
+import scala.reflect.{ClassTag, classTag}
 
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo}
+import org.apache.spark.sql.types._
 
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,7 +38,33 @@ trait Encoder[T] extends Serializable {
   def clsTag: ClassTag[T]
 }
 
+/**
+ * Methods for creating encoders.
+ */
 object Encoders {
+
+  /**
+   * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+   * This encoder maps T into a single byte array (binary) field.
+   */
+  def kryo[T: ClassTag]: Encoder[T] = {
+    val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true))
+    val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T])
+    ExpressionEncoder[T](
+      schema = new StructType().add("value", BinaryType),
+      flat = true,
+      toRowExpressions = Seq(ser),
+      fromRowExpression = deser,
+      clsTag = classTag[T]
+    )
+  }
+
+  /**
+   * Creates an encoder that serializes objects of type T using Kryo.
+   * This encoder maps T into a single byte array (binary) field.
+   */
+  def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
   def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
   def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
   def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9a1a8f5cbbdc..b977f278c5b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -161,7 +161,9 @@ case class ExpressionEncoder[T](
 
   @transient
   private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
-  private val inputRow = new GenericMutableRow(1)
+
+  @transient
+  private lazy val inputRow = new GenericMutableRow(1)
 
   @transient
   private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 414adb21168e..55c4ee11b20f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -230,7 +230,7 @@ object ProductEncoder {
         Invoke(inputObject, "booleanValue", BooleanType)
 
       case other =>
-        throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+        throw new UnsupportedOperationException(s"Encoder for type $other is not supported")
     }
   }
 }
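For context, the new Encoders.kryo API above is meant to be used roughly as follows. This is a minimal sketch, not part of the patch: the Point class is hypothetical, and an in-scope sqlContext is assumed.

    import org.apache.spark.sql.{Encoder, Encoders}

    // Point is neither a primitive nor a case class, so no built-in encoder
    // exists for it; Kryo handles the (de)serialization instead.
    class Point(val x: Int, val y: Int)

    // Scala variant: the ClassTag is resolved implicitly.
    implicit val pointEncoder: Encoder[Point] = Encoders.kryo[Point]

    // Java-friendly variant: pass the Class explicitly.
    val javaStyle: Encoder[Point] = Encoders.kryo(classOf[Point])

    // Each Point is stored as a single binary column named "value".
    val ds = sqlContext.createDataset(Seq(new Point(1, 2), new Point(3, 4)))
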
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 5cd19de68391..489c6126f8cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.language.existentials
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
-
-import scala.language.existentials
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
@@ -514,3 +516,64 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
     """
   }
 }
+
+/** Serializes an input object using the Kryo serializer. */
+case class SerializeWithKryo(child: Expression) extends UnaryExpression {
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val input = child.gen(ctx)
+    val kryo = ctx.freshName("kryoSerializer")
+    val kryoClass = classOf[KryoSerializer].getName
+    val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+    val sparkConfClass = classOf[SparkConf].getName
+    ctx.addMutableState(
+      kryoInstanceClass,
+      kryo,
+      s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+    s"""
+      ${input.code}
+      final boolean ${ev.isNull} = ${input.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = $kryo.serialize(${input.value}, null).array();
+      }
+    """
+  }
+
+  override def dataType: DataType = BinaryType
+}
+
+/**
+ * Deserializes an input object using the Kryo serializer. Note that the ClassTag is not an
+ * implicit parameter because TreeNode cannot copy implicit parameters.
+ */
+case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression {
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val input = child.gen(ctx)
+    val kryo = ctx.freshName("kryoSerializer")
+    val kryoClass = classOf[KryoSerializer].getName
+    val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+    val sparkConfClass = classOf[SparkConf].getName
+    ctx.addMutableState(
+      kryoInstanceClass,
+      kryo,
+      s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+    s"""
+      ${input.code}
+      final boolean ${ev.isNull} = ${input.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = (${ctx.javaType(dataType)})
+          $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
+      }
+    """
+  }
+
+  override def dataType: DataType = ObjectType(tag.runtimeClass)
+}
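The Java emitted by these two expressions reduces to a simple null-safe round trip through a cached serializer instance. The hand-written Scala below is only a sketch of that behavior against the public SerializerInstance API, not the actual codegen output:

    import java.nio.ByteBuffer

    import scala.reflect.ClassTag

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    // The codegen caches one serializer instance per expression via
    // ctx.addMutableState; here it is simply a val.
    val kryo = new KryoSerializer(new SparkConf()).newInstance()

    // Mirrors SerializeWithKryo: null stays null, otherwise object -> bytes.
    def serialize(obj: AnyRef): Array[Byte] =
      if (obj == null) null else kryo.serialize(obj).array()

    // Mirrors DeserializeWithKryo: null stays null, otherwise bytes -> object.
    def deserialize[T: ClassTag](bytes: Array[Byte]): T =
      if (bytes == null) null.asInstanceOf[T]
      else kryo.deserialize[T](ByteBuffer.wrap(bytes))
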
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
index 55821c437068..2729db84897a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.encoders
 
 import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.Encoders
 
 class FlatEncoderSuite extends ExpressionEncoderSuite {
   encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
@@ -71,4 +72,21 @@ class FlatEncoderSuite extends ExpressionEncoderSuite {
   encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
   encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
     FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
+
+  // Kryo encoders
+  encodeDecodeTest(
+    "hello",
+    Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]],
+    "kryo string")
+  encodeDecodeTest(
+    new NotJavaSerializable(15),
+    Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]],
+    "kryo object serialization")
+}
+
+
+class NotJavaSerializable(val value: Int) {
+  override def equals(other: Any): Boolean = {
+    this.value == other.asInstanceOf[NotJavaSerializable].value
+  }
 }
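At its core, what these Kryo cases exercise is a round trip through the encoder. The condensed sketch below is illustrative only (the roundTrip helper is not real, and the suite's encodeDecodeTest does more than this):

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Encode to an InternalRow, then decode back; the comparison relies on
    // NotJavaSerializable.equals, which is why the class overrides it above.
    def roundTrip[T](enc: ExpressionEncoder[T], input: T): T =
      enc.fromRow(enc.toRow(input).copy())

    val enc = Encoders.kryo[NotJavaSerializable]
      .asInstanceOf[ExpressionEncoder[NotJavaSerializable]]
    assert(roundTrip(enc, new NotJavaSerializable(15)) == new NotJavaSerializable(15))
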
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 718ed812dd64..817c20fdbb9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -147,6 +147,12 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Returns the number of elements in the [[Dataset]].
+   * @since 1.6.0
+   */
+  def count(): Long = toDF().count()
+
   /* *********************** *
    *  Functional Operations  *
    * *********************** */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 467cd42b9b8d..c66162ee2148 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql
 
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ea29428c5508..a522894c374f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,21 +24,6 @@ import scala.language.postfixOps
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-case class ClassData(a: String, b: Int)
-
-/**
- * A class used to test serialization using encoders. This class throws exceptions when using
- * Java serialization -- so the only way it can be "serialized" is through our encoders.
- */
-case class NonSerializableCaseClass(value: String) extends Externalizable {
-  override def readExternal(in: ObjectInput): Unit = {
-    throw new UnsupportedOperationException
-  }
-
-  override def writeExternal(out: ObjectOutput): Unit = {
-    throw new UnsupportedOperationException
-  }
-}
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -362,8 +347,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(joined, ("2", 2))
   }
 
+  ignore("self join") {
+    val ds = Seq("1", "2").toDS().as("a")
+    val joined = ds.joinWith(ds, lit(true))
+    checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
+  }
+
   test("toString") {
     val ds = Seq((1, 2)).toDS()
     assert(ds.toString == "[_1: int, _2: int]")
   }
+
+  test("kryo encoder") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+
+    assert(ds.groupBy(p => p).count().collect().toSeq ==
+      Seq((KryoData(1), 1L), (KryoData(2), 1L)))
+  }
+
+  ignore("kryo encoder self join") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+    assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+      Set(
+        (KryoData(1), KryoData(1)),
+        (KryoData(1), KryoData(2)),
+        (KryoData(2), KryoData(1)),
+        (KryoData(2), KryoData(2))))
+  }
+}
+
+
+case class ClassData(a: String, b: Int)
+
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+  override def readExternal(in: ObjectInput): Unit = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    throw new UnsupportedOperationException
+  }
+}
+
+/** Used to test the Kryo encoder. */
+class KryoData(val a: Int) {
+  override def equals(other: Any): Boolean = {
+    a == other.asInstanceOf[KryoData].a
+  }
+  override def hashCode: Int = a
+  override def toString: String = s"KryoData($a)"
+}
+
+object KryoData {
+  def apply(a: Int): KryoData = new KryoData(a)
 }
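End to end, the enabled "kryo encoder" test corresponds to the user-level flow below. This is a sketch only; it assumes a live SQLContext, mirroring what the suite gets from SharedSQLContext and testImplicits:

    import org.apache.spark.sql.{Encoders, SQLContext}

    def kryoDemo(sqlContext: SQLContext): Unit = {
      import sqlContext.implicits._

      // Without this implicit there is no encoder for the non-case class KryoData.
      implicit val kryoEncoder = Encoders.kryo[KryoData]

      val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))

      // Each value is stored as a single binary column, so the whole object can
      // serve as a grouping key and the new Dataset.count() works as expected.
      assert(ds.count() == 2L)
      assert(ds.groupBy(p => p).count().collect().toSeq ==
        Seq((KryoData(1), 1L), (KryoData(2), 1L)))
    }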