TestAvroConversionUtils.scala
@@ -18,8 +18,13 @@

package org.apache.hudi

+import java.nio.ByteBuffer
+import java.util.Objects
import org.apache.avro.Schema
-import org.apache.spark.sql.types.{DataTypes, StructType, StringType, ArrayType}
+import org.apache.avro.generic.GenericData
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, MapType, StringType, StructField, StructType}
import org.scalatest.{FunSuite, Matchers}

class TestAvroConversionUtils extends FunSuite with Matchers {
@@ -377,4 +382,54 @@ class TestAvroConversionUtils extends FunSuite with Matchers {

    assert(avroSchema.equals(expectedAvroSchema))
  }

test("test converter with binary") {
val avroSchema = new Schema.Parser().parse("{\"type\":\"record\",\"name\":\"h0_record\",\"namespace\":\"hoodie.h0\",\"fields\""
+ ":[{\"name\":\"col9\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}")
val sparkSchema = StructType(List(StructField("col9", BinaryType, nullable = true)))
// create a test record with avroSchema
val avroRecord = new GenericData.Record(avroSchema)
val bb = ByteBuffer.wrap(Array[Byte](97, 48, 53))
avroRecord.put("col9", bb)
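    // Convert the same record twice: without the rewind() fix in AvroDeserializer,
    // the first conversion would exhaust the shared ByteBuffer and the second
    // would read zero remaining bytes.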
    val row1 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get
    val row2 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get
    internalRowCompare(row1, row2, sparkSchema)
  }

  // Recursively compares two Catalyst values according to the given schema and
  // throws an AssertionError on the first mismatch.
  private def internalRowCompare(expected: Any, actual: Any, schema: DataType): Unit = {
    schema match {
      case StructType(fields) =>
        val expectedRow = expected.asInstanceOf[InternalRow]
        val actualRow = actual.asInstanceOf[InternalRow]
        fields.zipWithIndex.foreach { case (field, i) =>
          internalRowCompare(expectedRow.get(i, field.dataType), actualRow.get(i, field.dataType), field.dataType)
        }
      case ArrayType(elementType, _) =>
        val expectedArray = expected.asInstanceOf[ArrayData].toSeq[Any](elementType)
        val actualArray = actual.asInstanceOf[ArrayData].toSeq[Any](elementType)
        if (expectedArray.size != actualArray.size) {
          throw new AssertionError(s"array sizes differ: ${expectedArray.size} vs ${actualArray.size}")
        } else {
          expectedArray.zip(actualArray).foreach { case (e1, e2) => internalRowCompare(e1, e2, elementType) }
        }
      case MapType(keyType, valueType, _) =>
        // compare keys and values as arrays; MapData exposes no per-entry accessor
        val expectedKeyArray = expected.asInstanceOf[MapData].keyArray()
        val expectedValueArray = expected.asInstanceOf[MapData].valueArray()
        val actualKeyArray = actual.asInstanceOf[MapData].keyArray()
        val actualValueArray = actual.asInstanceOf[MapData].valueArray()
        internalRowCompare(expectedKeyArray, actualKeyArray, ArrayType(keyType))
        internalRowCompare(expectedValueArray, actualValueArray, ArrayType(valueType))
      case StringType => if (checkNull(expected, actual) || (expected != null && !expected.toString.equals(actual.toString))) {
        throw new AssertionError(s"$expected is not equal to $actual")
      }
      case BinaryType => if (checkNull(expected, actual) || (expected != null && !expected.asInstanceOf[Array[Byte]].sameElements(actual.asInstanceOf[Array[Byte]]))) {
        throw new AssertionError(s"$expected is not equal to $actual")
      }
      case _ => if (!Objects.equals(expected, actual)) {
        throw new AssertionError(s"$expected is not equal to $actual")
      }
    }
  }

  // Returns true when exactly one of the two values is null.
  private def checkNull(left: Any, right: Any): Boolean = {
    (left == null && right != null) || (left != null && right == null)
  }
}
@@ -146,6 +146,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
      case b: ByteBuffer =>
        val bytes = new Array[Byte](b.remaining)
        b.get(bytes)
+       // Do not forget to reset the position
+       b.rewind()
Contributor:
The class is copied from Spark; does Spark have the same fix?

Contributor Author:
Not yet.
Comment on lines +149 to +150

Member:
> // Do not forget to reset the position

This would be a helpful note if you can add more details, or better, have a UT cover the problem's case.

Contributor Author:
I will add a UT for it.

        bytes
      case b: Array[Byte] => b
      case other => throw new RuntimeException(s"$other is not a valid avro binary.")
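For context, here is a minimal standalone sketch (not part of this PR; the object and method names are illustrative) of the failure mode the rewind() call above fixes: ByteBuffer.get advances the buffer's position, so a second deserialization of the same Avro record would otherwise find zero remaining bytes.

import java.nio.ByteBuffer

object ByteBufferRewindDemo {
  // Mirrors the deserializer's read pattern: copy the remaining bytes into an
  // array, which moves the buffer's position up to its limit.
  def readBytes(b: ByteBuffer): Array[Byte] = {
    val bytes = new Array[Byte](b.remaining)
    b.get(bytes) // position == limit after this call
    b.rewind()   // reset the position so the buffer can be read again
    bytes
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.wrap(Array[Byte](97, 48, 53))
    println(readBytes(buf).mkString(",")) // 97,48,53
    // Without the rewind() inside readBytes, this second call would see
    // buf.remaining == 0 and print an empty line.
    println(readBytes(buf).mkString(",")) // 97,48,53
  }
}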
@@ -167,6 +167,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
      case b: ByteBuffer =>
        val bytes = new Array[Byte](b.remaining)
        b.get(bytes)
+       // Do not forget to reset the position
+       b.rewind()
        bytes
      case b: Array[Byte] => b
      case other => throw new RuntimeException(s"$other is not a valid avro binary.")
@@ -181,6 +181,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
      case b: ByteBuffer =>
        val bytes = new Array[Byte](b.remaining)
        b.get(bytes)
+       // Do not forget to reset the position
+       b.rewind()
        bytes
      case b: Array[Byte] => b
      case other =>