@@ -90,8 +90,8 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
if (requiredSchema.isEmpty) {
filteredResult.map(_ => emptyUnsafeRow)
} else {
val converter = RowEncoder(requiredSchema)
filteredResult.map(row => converter.toRow(row))
val toRow = RowEncoder(requiredSchema).createSerializer()
filteredResult.map(row => toRow(row))
}
}
}
@@ -166,7 +166,7 @@ private[libsvm] class LibSVMFileFormat
LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
}

val converter = RowEncoder(dataSchema)
val toRow = RowEncoder(dataSchema).createSerializer()
val fullOutput = dataSchema.map { f =>
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
}
@@ -178,7 +178,7 @@ private[libsvm] class LibSVMFileFormat

points.map { pt =>
val features = if (isSparse) pt.features.toSparse else pt.features.toDense
requiredColumns(converter.toRow(Row(pt.label, features)))
requiredColumns(toRow(Row(pt.label, features)))
}
}
}
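The two data-source changes above follow the same caller-side migration: instead of calling `toRow` on the encoder, the caller asks the encoder for a serializer once and applies that function per row. A minimal sketch of the pattern, with a hypothetical schema and input (not taken from the PR):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, StructType}

// Hypothetical schema and rows, only to show the call shape.
val schema = new StructType().add("id", LongType).add("name", StringType)
val rows: Iterator[Row] = Iterator(Row(1L, "a"), Row(2L, "b"))

// Before: val converter = RowEncoder(schema); rows.map(converter.toRow)
// After:  build the serializer once, then apply it like a plain function.
val toRow = RowEncoder(schema).createSerializer()
val internalRows: Iterator[InternalRow] = rows.map(toRow)
```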
@@ -38,20 +38,22 @@ object UDTSerializationBenchmark extends BenchmarkBase {
val iters = 1e2.toInt
val numRows = 1e3.toInt

val encoder = ExpressionEncoder[Vector].resolveAndBind()
val encoder = ExpressionEncoder[Vector]().resolveAndBind()
val toRow = encoder.createSerializer()
val fromRow = encoder.createDeserializer()

val vectors = (1 to numRows).map { i =>
Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
}.toArray
val rows = vectors.map(encoder.toRow)
val rows = vectors.map(toRow)

val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters, output = output)

benchmark.addCase("serialize") { _ =>
var sum = 0
var i = 0
while (i < numRows) {
sum += encoder.toRow(vectors(i)).numFields
sum += toRow(vectors(i)).numFields
i += 1
}
}
@@ -60,7 +62,7 @@ object UDTSerializationBenchmark extends BenchmarkBase {
var sum = 0
var i = 0
while (i < numRows) {
sum += encoder.fromRow(rows(i)).numActives
sum += fromRow(rows(i)).numActives
i += 1
}
}
10 changes: 0 additions & 10 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -57,10 +57,6 @@ import org.apache.spark.sql.types._
* Encoders.bean(MyClass.class);
* }}}
*
* == Implementation ==
* - Encoders are not required to be thread-safe and thus they do not need to use locks to guard
* against concurrent access if they reuse internal buffers to improve performance.
*
Member: Removing this means Encoders must be thread-safe? Do we need an explicit comment for that?

Contributor Author: I think it should be thread-safe.

* @since 1.6.0
*/
@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " +
@@ -76,10 +72,4 @@ trait Encoder[T] extends Serializable {
* A ClassTag that can be used to construct an Array to contain a collection of `T`.
*/
def clsTag: ClassTag[T]

/**
* Create a copied [[Encoder]]. The implementation may just copy internal reusable fields to speed
* up the [[Encoder]] creation.
*/
def makeCopy: Encoder[T]
}
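Given the thread above, the working assumption is that the `Encoder` itself stays shareable while the objects returned by `createSerializer()`/`createDeserializer()` do not. A sketch of how a caller could split the two under that assumption (the record type and threading setup are hypothetical):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical record type, only for illustration.
case class User(id: Long, name: String)

// Build and bind the encoder once; per the discussion above it is treated as shareable.
val encoder = ExpressionEncoder[User]().resolveAndBind()

// Each thread creates its own serializer/deserializer, since those instances
// are documented as not thread-safe and may reuse internal buffers.
val worker = new Runnable {
  override def run(): Unit = {
    val toRow = encoder.createSerializer()
    val fromRow = encoder.createDeserializer()
    val roundTripped = fromRow(toRow(User(1L, "a")))
    assert(roundTripped == User(1L, "a"))
  }
}
new Thread(worker).start()
```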
Expand Up @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
@@ -162,6 +163,18 @@ object ExpressionEncoder {
e4: ExpressionEncoder[T4],
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]

/**
* Function that deserializes an [[InternalRow]] into an object of type `T`. Instances of this
* class are not meant to be thread-safe.
*/
abstract class Deserializer[T] extends (InternalRow => T)
Contributor Author: Implementation note. I opted to go with abstract classes so we can get monomorphic call sites in many cases.


/**
* Function that serializes an object of type `T` to an [[InternalRow]]. Instances of this
* class are not meant to be thread-safe.
*/
abstract class Serializer[T] extends (T => InternalRow)
}
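To illustrate the implementation note above: because `Serializer[T]` and `Deserializer[T]` extend `T => InternalRow` and `InternalRow => T`, they still compose like ordinary Scala functions, while call sites typed against the abstract classes see a much narrower type than `Function1`, which is what the monomorphic-call-site remark refers to. A rough sketch (the `Point` type is hypothetical):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Point(x: Double, y: Double)

val toRow = ExpressionEncoder[Point]().createSerializer()

// Still usable as a plain function; copy() because the serializer may reuse its buffer.
val rows: Seq[InternalRow] = Seq(Point(1, 2), Point(3, 4)).map(p => toRow(p).copy())

// A call site can also be typed against the abstract class rather than Function1.
def encodeAll(
    points: Iterator[Point],
    serialize: ExpressionEncoder.Serializer[Point]): Iterator[InternalRow] =
  points.map(serialize)
```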

/**
@@ -302,25 +315,22 @@ case class ExpressionEncoder[T](
}

@transient
private lazy val extractProjection = GenerateUnsafeProjection.generate({
private lazy val optimizedDeserializer: Seq[Expression] = {
// When using `ExpressionEncoder` directly, we will skip the normal query processing steps
// (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
// important to codegen performance.
val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
})

@transient
private lazy val inputRow = new GenericInternalRow(1)
}

@transient
private lazy val constructProjection = SafeProjection.create({
private lazy val optimizedSerializer = {
// When using `ExpressionEncoder` directly, we will skip the normal query processing steps
// (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
// important to codegen performance.
val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
})
}

/**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
@@ -332,30 +342,43 @@ case class ExpressionEncoder[T](
}

/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
* copy the result before making another call if required.
* Create a serializer that can convert an object of type `T` to a Spark SQL Row.
*
* Note that the returned [[Serializer]] is not thread safe. Multiple calls to
* `serializer.apply(..)` are allowed to return the same actual [[InternalRow]] object. Thus,
* the caller should copy the result before making another call if required.
*/
def toRow(t: T): InternalRow = try {
inputRow(0) = t
extractProjection(inputRow)
} catch {
case e: Exception =>
throw new RuntimeException(s"Error while encoding: $e\n" +
def createSerializer(): Serializer[T] = new Serializer[T] {
Contributor Author: This currently relies on serializing the enclosing encoder as well. We technically don't need the entire encoder, only a couple of fields. I could move this class into the companion object and just use the fields I need.

Contributor: Yeah, that would be better.

Contributor: I'm also thinking about whether we could just pass the original (de)serializer expressions and do the optimization inside Serializer and Deserializer lazily.

Contributor Author: There is some cost to the optimization, so I would like to do it only once.

private val inputRow = new GenericInternalRow(1)

private val extractProjection = GenerateUnsafeProjection.generate(optimizedSerializer)

override def apply(t: T): InternalRow = try {
inputRow(0) = t
extractProjection(inputRow)
} catch {
case e: Exception =>
throw new RuntimeException(s"Error while encoding: $e\n" +
s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
}
}

/**
* Returns an object of type `T`, extracting the required values from the provided row. Note that
* you must `resolveAndBind` an encoder to a specific schema before you can call this
* function.
* Create a deserializer that can convert a Spark SQL Row into an object of type `T`.
*
* Note that you must `resolveAndBind` an encoder to a specific schema before you can create a
* deserializer.
*/
def fromRow(row: InternalRow): T = try {
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Error while decoding: $e\n" +
s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
def createDeserializer(): Deserializer[T] = new Deserializer[T] {
private val constructProjection = SafeProjection.create(optimizedDeserializer)

override def apply(row: InternalRow): T = try {
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Error while decoding: $e\n" +
s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
Member: As pointed out already, some fields like deserializer live in the enclosing encoder, so it looks like we will currently serialize the entire encoder? Actually we already serialize the entire encoder today, but it would be better if we could get rid of the unnecessary parts.

}
}
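The two threads above are about what gets captured when one of these anonymous inner classes is Java-serialized: as written, the instance holds a reference to the enclosing `ExpressionEncoder`. A sketch of the alternative being discussed, with the concrete class kept in the companion object and capturing only the expressions it needs (class name and exact shape are hypothetical, not part of this revision):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection

// Hypothetical: lives in the ExpressionEncoder companion object instead of being
// an anonymous inner class, so serializing it does not drag the encoder along.
class ExpressionSerializer[T](serializerExprs: Seq[Expression])
  extends ExpressionEncoder.Serializer[T] {

  @transient private lazy val extractProjection =
    GenerateUnsafeProjection.generate(serializerExprs)
  @transient private lazy val inputRow = new GenericInternalRow(1)

  override def apply(t: T): InternalRow = {
    inputRow(0) = t
    extractProjection(inputRow)
  }
}
```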

/**
@@ -383,8 +406,6 @@ case class ExpressionEncoder[T](
.map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")

override def toString: String = s"class[$schemaString]"

override def makeCopy: ExpressionEncoder[T] = copy()
}

// A dummy logical plan that can hold expressions and go through optimizer rules.
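Taken together, the new docs above leave callers with two caveats: the serializer may hand back the same `InternalRow` instance on every call, and a deserializer can only be created from a resolved and bound encoder. A small usage sketch of both (the `Event` type is hypothetical):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Event(id: Long, payload: String)

val encoder = ExpressionEncoder[Event]()

// Caveat 1: copy() before buffering, because the returned row object may be reused.
val toRow = encoder.createSerializer()
val buffered: Seq[InternalRow] =
  Seq(Event(1L, "a"), Event(2L, "b")).map(e => toRow(e).copy())

// Caveat 2: deserialization needs a resolved and bound encoder.
val fromRow = encoder.resolveAndBind().createDeserializer()
val events: Seq[Event] = buffered.map(fromRow)
```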
@@ -110,8 +110,8 @@ case class ScalaUDF(
} else {
val encoder = inputEncoders(i)
if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
val enc = encoder.get.resolveAndBind()
row: Any => enc.fromRow(row.asInstanceOf[InternalRow])
val fromRow = encoder.get.resolveAndBind().createDeserializer()
row: Any => fromRow(row.asInstanceOf[InternalRow])
} else {
CatalystTypeConverters.createToScalaConverter(dataType)
}
@@ -41,13 +41,13 @@ object HashBenchmark extends BenchmarkBase {
def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = {
runBenchmark(name) {
val generator = RandomDataGenerator.forType(schema, nullable = false).get
val encoder = RowEncoder(schema)
val toRow = RowEncoder(schema).createSerializer()
val attrs = schema.toAttributes
val safeProjection = GenerateSafeProjection.generate(attrs, attrs)

val rows = (1 to numRows).map(_ =>
// The output of encoder is UnsafeRow, use safeProjection to turn it into safe format.
safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
safeProjection(toRow(generator().asInstanceOf[Row])).copy()
).toArray

val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output)
@@ -37,8 +37,8 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {

def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
val generator = RandomDataGenerator.forType(schema, nullable = false).get
val encoder = RowEncoder(schema)
(1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
val toRow = RowEncoder(schema).createSerializer()
(1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy()).toArray
}

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
@@ -42,37 +43,44 @@ case class NestedArrayClass(nestedArr: Array[ArrayClass])
class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")

def testFromRow[T](
encoder: ExpressionEncoder[T],
attributes: Seq[Attribute],
row: InternalRow): Unit = {
encoder.resolveAndBind(attributes).createDeserializer()(row)
}

test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]

// int type can be up cast to long type
val attrs1 = Seq('a.string, 'b.int)
encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
testFromRow(encoder, attrs1, InternalRow(str, 1))

// int type can be up cast to string type
val attrs2 = Seq('a.int, 'b.long)
encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
testFromRow(encoder, attrs2, InternalRow(1, 2L))
}

test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L)))
}

test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
}

test("real type doesn't match encoder schema but they are compatible: primitive array") {
val encoder = ExpressionEncoder[PrimitiveArrayClass]
val attrs = Seq('arr.array(IntegerType))
val array = new GenericArrayData(Array(1, 2, 3))
encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
testFromRow(encoder, attrs, InternalRow(array))
}

test("the real type is not compatible with encoder schema: primitive array") {
@@ -93,7 +101,7 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
testFromRow(encoder, attrs, InternalRow(array))
}

test("real type doesn't match encoder schema but they are compatible: nested array") {
@@ -103,7 +111,7 @@
val attrs = Seq('nestedArr.array(et))
val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
testFromRow(encoder, attrs, InternalRow(outerArr))
}

test("the real type is not compatible with encoder schema: non-array field") {
@@ -142,14 +150,14 @@ class EncoderResolutionSuite extends PlanTest {
val attrs = 'a.array(IntegerType) :: Nil

// It should pass analysis
val bound = encoder.resolveAndBind(attrs)
val fromRow = encoder.resolveAndBind(attrs).createDeserializer()

// If no null values appear, it should work fine
bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
fromRow(InternalRow(new GenericArrayData(Array(1, 2))))

// If there is null value, it should throw runtime exception
val e = intercept[RuntimeException] {
bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
fromRow(InternalRow(new GenericArrayData(Array(1, null))))
}
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}