diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala
index bacd44753df35..16df1f869c6bc 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala
@@ -18,8 +18,13 @@
 package org.apache.hudi
 
+import java.nio.ByteBuffer
+import java.util.Objects
 import org.apache.avro.Schema
-import org.apache.spark.sql.types.{DataTypes, StructType, StringType, ArrayType}
+import org.apache.avro.generic.GenericData
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, MapType, StringType, StructField, StructType}
 import org.scalatest.{FunSuite, Matchers}
 
 class TestAvroConversionUtils extends FunSuite with Matchers {
@@ -377,4 +382,54 @@ class TestAvroConversionUtils extends FunSuite with Matchers {
     assert(avroSchema.equals(expectedAvroSchema))
   }
+
+  test("test converter with binary") {
+    val avroSchema = new Schema.Parser().parse("{\"type\":\"record\",\"name\":\"h0_record\",\"namespace\":\"hoodie.h0\",\"fields\"" +
+      ":[{\"name\":\"col9\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}")
+    val sparkSchema = StructType(List(StructField("col9", BinaryType, nullable = true)))
+    // Create a test record conforming to avroSchema, then convert it twice:
+    // without the rewind fix the second conversion reads an exhausted buffer
+    val avroRecord = new GenericData.Record(avroSchema)
+    val bb = ByteBuffer.wrap(Array[Byte](97, 48, 53))
+    avroRecord.put("col9", bb)
+    val row1 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get
+    val row2 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get
+    internalRowCompare(row1, row2, sparkSchema)
+  }
+
+  private def internalRowCompare(expected: Any, actual: Any, schema: DataType): Unit = {
+    schema match {
+      case StructType(fields) =>
+        val expectedRow = expected.asInstanceOf[InternalRow]
+        val actualRow = actual.asInstanceOf[InternalRow]
+        fields.zipWithIndex.foreach { case (field, i) => internalRowCompare(expectedRow.get(i, field.dataType), actualRow.get(i, field.dataType), field.dataType) }
+      case ArrayType(elementType, _) =>
+        val expectedArray = expected.asInstanceOf[ArrayData].toSeq[Any](elementType)
+        val actualArray = actual.asInstanceOf[ArrayData].toSeq[Any](elementType)
+        if (expectedArray.size != actualArray.size) {
+          throw new AssertionError()
+        } else {
+          expectedArray.zip(actualArray).foreach { case (e1, e2) => internalRowCompare(e1, e2, elementType) }
+        }
+      case MapType(keyType, valueType, _) =>
+        val expectedKeyArray = expected.asInstanceOf[MapData].keyArray()
+        val expectedValueArray = expected.asInstanceOf[MapData].valueArray()
+        val actualKeyArray = actual.asInstanceOf[MapData].keyArray()
+        val actualValueArray = actual.asInstanceOf[MapData].valueArray()
+        internalRowCompare(expectedKeyArray, actualKeyArray, ArrayType(keyType))
+        internalRowCompare(expectedValueArray, actualValueArray, ArrayType(valueType))
+      case StringType => if (checkNull(expected, actual) || (expected != null && !expected.toString.equals(actual.toString))) {
+        throw new AssertionError(String.format("%s is not equal to %s", expected, actual))
+      }
+      case BinaryType => if (checkNull(expected, actual) || (expected != null &&
+        !expected.asInstanceOf[Array[Byte]].sameElements(actual.asInstanceOf[Array[Byte]]))) {
+        throw new AssertionError(String.format("%s is not equal to %s", expected, actual))
+      }
+      case _ => if (!Objects.equals(expected, actual)) {
+        throw new AssertionError(String.format("%s is not equal to %s", expected, actual))
+      }
+    }
+  }
+
+  // Returns true when exactly one of the two values is null
+  private def checkNull(left: Any, right: Any): Boolean = {
+    (left == null && right != null) || (left != null && right == null)
+  }
 }
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 2e0946f1eb989..385577dd30b84 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -146,6 +146,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
       case b: ByteBuffer =>
         val bytes = new Array[Byte](b.remaining)
         b.get(bytes)
+        // Reset the position so the buffer can be read again if the record is deserialized more than once
+        b.rewind()
         bytes
       case b: Array[Byte] => b
       case other => throw new RuntimeException(s"$other is not a valid avro binary.")
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 717df0f4076ee..5fb6d907bdc82 100644
--- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -167,6 +167,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
       case b: ByteBuffer =>
         val bytes = new Array[Byte](b.remaining)
         b.get(bytes)
+        // Reset the position so the buffer can be read again if the record is deserialized more than once
+        b.rewind()
         bytes
       case b: Array[Byte] => b
       case other => throw new RuntimeException(s"$other is not a valid avro binary.")
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index ef9b5909207ca..0b609330756eb 100644
--- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -181,6 +181,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
       case b: ByteBuffer =>
         val bytes = new Array[Byte](b.remaining)
         b.get(bytes)
+        // Reset the position so the buffer can be read again if the record is deserialized more than once
+        b.rewind()
         bytes
       case b: Array[Byte] => b
       case other =>
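For context on why the same `b.rewind()` call is added in all three deserializers: a relative `get` advances a `ByteBuffer`'s position to its limit, so a second conversion of the same Avro record sees `remaining == 0` and copies nothing. A minimal standalone sketch of the pitfall, assuming only `java.nio` (the object name is made up for illustration, not part of the patch):

```scala
import java.nio.ByteBuffer

object ByteBufferRewindDemo {
  def main(args: Array[String]): Unit = {
    val b = ByteBuffer.wrap(Array[Byte](97, 48, 53))

    // First read: copies the bytes and advances the position to the limit.
    val first = new Array[Byte](b.remaining)
    b.get(first)
    println(first.mkString(",")) // 97,48,53

    // Without a rewind, a second read sees remaining == 0 and copies nothing;
    // this is why the second conversion of a record yielded an empty binary
    // field before the fix.
    val exhausted = new Array[Byte](b.remaining)
    b.get(exhausted)
    println(exhausted.length) // 0

    // rewind() resets the position to 0 so the buffer can be read again.
    b.rewind()
    val second = new Array[Byte](b.remaining)
    b.get(second)
    println(second.mkString(",")) // 97,48,53
  }
}
```

The new test exercises exactly this path by converting one `GenericData.Record` with a bytes field twice and comparing the resulting rows field by field.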