31 changes: 29 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,10 +17,11 @@

package org.apache.spark.sql

import scala.reflect.ClassTag
import scala.reflect.{ClassTag, classTag}

import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo}
import org.apache.spark.sql.types._

/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,7 +38,33 @@ trait Encoder[T] extends Serializable {
def clsTag: ClassTag[T]
}

/**
* Methods for creating encoders.
*/
object Encoders {

/**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
* This encoder maps T into a single byte array (binary) field.
*/
def kryo[T: ClassTag]: Encoder[T] = {
Inline review thread:

Contributor: How about putting this method into ExpressionEncoder? I think Encoders is only used on the Java side, given the lack of implicit magic.

Contributor: Then we could change the return type to ExpressionEncoder[T] and make the tests less verbose by dropping the asInstanceOf casts.

Author: What if we want to use Kryo in Scala?

Contributor: Oh, this can't be done with implicits. I was wrong, never mind.

val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true))
val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T])
ExpressionEncoder[T](
schema = new StructType().add("value", BinaryType),
flat = true,
toRowExpressions = Seq(ser),
fromRowExpression = deser,
clsTag = classTag[T]
)
}

/**
* Creates an encoder that serializes objects of type T using Kryo.
* This encoder maps T into a single byte array (binary) field.
*/
def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
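
A minimal usage sketch (illustration only; `Point` is a hypothetical class, not part of this change): the Scala overload picks up the ClassTag implicitly, while Java callers pass the Class explicitly.

  import org.apache.spark.sql.{Encoder, Encoders}

  class Point(val x: Int, val y: Int) // hypothetical payload type

  // Scala: the ClassTag for Point is summoned implicitly.
  implicit val pointEncoder: Encoder[Point] = Encoders.kryo[Point]

  // Java equivalent: the Class is passed explicitly instead.
  //   Encoder<Point> pointEncoder = Encoders.kryo(Point.class);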

def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -161,7 +161,9 @@ case class ExpressionEncoder[T](

@transient
private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
private val inputRow = new GenericMutableRow(1)

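// Not serialized with the encoder; re-created lazily on first access after deserialization.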
@transient
private lazy val inputRow = new GenericMutableRow(1)

@transient
private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -230,7 +230,7 @@ object ProductEncoder {
Invoke(inputObject, "booleanValue", BooleanType)

case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
throw new UnsupportedOperationException(s"Encoder for type $other is not supported")
}
}
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,13 +17,15 @@

package org.apache.spark.sql.catalyst.expressions

import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import org.apache.spark.sql.catalyst.util.GenericArrayData

import scala.language.existentials

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
@@ -514,3 +516,64 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
"""
}
}

/** Serializes an input object using the Kryo serializer. */
case class SerializeWithKryo(child: Expression) extends UnaryExpression {

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val input = child.gen(ctx)
val kryo = ctx.freshName("kryoSerializer")
val kryoClass = classOf[KryoSerializer].getName
val kryoInstanceClass = classOf[KryoSerializerInstance].getName
val sparkConfClass = classOf[SparkConf].getName
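// Register the serializer instance as a field of the generated class, so it is
// created once in the class initializer and reused across rows.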
ctx.addMutableState(
kryoInstanceClass,
kryo,
s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")

s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $kryo.serialize(${input.value}, null).array();
}
"""
}

override def dataType: DataType = BinaryType
}

/**
* Deserializes an input object using the Kryo serializer. Note that the ClassTag is not an
* implicit parameter because TreeNode cannot copy implicit parameters.
*/
case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression {

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val input = child.gen(ctx)
val kryo = ctx.freshName("kryoSerializer")
val kryoClass = classOf[KryoSerializer].getName
val kryoInstanceClass = classOf[KryoSerializerInstance].getName
val sparkConfClass = classOf[SparkConf].getName
ctx.addMutableState(
kryoInstanceClass,
kryo,
s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")

s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.javaType(dataType)})
$kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
}
"""
}

override def dataType: DataType = ObjectType(tag.runtimeClass)
}
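
At runtime, the code generated by these two expressions amounts to a Kryo round trip through Spark's SerializerInstance API. A minimal sketch of that behavior, using only KryoSerializer from Spark core:

  import java.nio.ByteBuffer

  import org.apache.spark.SparkConf
  import org.apache.spark.serializer.KryoSerializer

  val ser = new KryoSerializer(new SparkConf()).newInstance()

  // SerializeWithKryo: object -> single binary field
  val bytes: Array[Byte] = ser.serialize("hello").array()

  // DeserializeWithKryo: binary field -> object of the ClassTag's type
  val back: String = ser.deserialize[String](ByteBuffer.wrap(bytes))
  assert(back == "hello")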
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.encoders

import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Encoders

class FlatEncoderSuite extends ExpressionEncoderSuite {
encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
@@ -71,4 +72,21 @@ class FlatEncoderSuite extends ExpressionEncoderSuite {
encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
FlatEncoder[Map[Int, Map[String, Int]]], "map of map")

// Kryo encoders
encodeDecodeTest(
"hello",
Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]],
"kryo string")
encodeDecodeTest(
new NotJavaSerializable(15),
Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]],
"kryo object serialization")
}


class NotJavaSerializable(val value: Int) {
override def equals(other: Any): Boolean = {
this.value == other.asInstanceOf[NotJavaSerializable].value
}
}
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -147,6 +147,12 @@ class Dataset[T] private[sql](
}
}

/**
* Returns the number of elements in the [[Dataset]].
* @since 1.6.0
*/
def count(): Long = toDF().count()
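
A quick usage sketch, assuming a SQLContext with its implicits imported (as the test suites below do):

  import sqlContext.implicits._

  val ds = Seq(1, 2, 3).toDS()
  assert(ds.count() == 3L)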

/* *********************** *
* Functional Operations *
* *********************** */
@@ -17,7 +17,6 @@

package org.apache.spark.sql


import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
70 changes: 55 additions & 15 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,21 +24,6 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext

case class ClassData(a: String, b: Int)

/**
* A class used to test serialization using encoders. This class throws exceptions when using
* Java serialization -- so the only way it can be "serialized" is through our encoders.
*/
case class NonSerializableCaseClass(value: String) extends Externalizable {
override def readExternal(in: ObjectInput): Unit = {
throw new UnsupportedOperationException
}

override def writeExternal(out: ObjectOutput): Unit = {
throw new UnsupportedOperationException
}
}

class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -362,8 +347,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(joined, ("2", 2))
}

ignore("self join") {
val ds = Seq("1", "2").toDS().as("a")
val joined = ds.joinWith(ds, lit(true))
checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
}

test("toString") {
val ds = Seq((1, 2)).toDS()
assert(ds.toString == "[_1: int, _2: int]")
}

test("kryo encoder") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))

assert(ds.groupBy(p => p).count().collect().toSeq ==
Seq((KryoData(1), 1L), (KryoData(2), 1L)))
}

ignore("kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
Set(
(KryoData(1), KryoData(1)),
(KryoData(1), KryoData(2)),
(KryoData(2), KryoData(1)),
(KryoData(2), KryoData(2))))
}
}


case class ClassData(a: String, b: Int)

/**
* A class used to test serialization using encoders. This class throws exceptions when using
* Java serialization -- so the only way it can be "serialized" is through our encoders.
*/
case class NonSerializableCaseClass(value: String) extends Externalizable {
override def readExternal(in: ObjectInput): Unit = {
throw new UnsupportedOperationException
}

override def writeExternal(out: ObjectOutput): Unit = {
throw new UnsupportedOperationException
}
}

/** Used to test Kryo encoder. */
class KryoData(val a: Int) {
override def equals(other: Any): Boolean = {
a == other.asInstanceOf[KryoData].a
}
override def hashCode: Int = a
override def toString: String = s"KryoData($a)"
}

object KryoData {
def apply(a: Int): KryoData = new KryoData(a)
}