diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/avro/TestAvroSerDe.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/avro/TestAvroSerDe.scala index 069d56f282d27..4e0e024e37afd 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/avro/TestAvroSerDe.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/avro/TestAvroSerDe.scala @@ -21,14 +21,14 @@ import org.apache.avro.generic.GenericData import org.apache.hudi.SparkAdapterSupport import org.apache.hudi.avro.model.{HoodieMetadataColumnStats, IntWrapper} import org.apache.spark.sql.avro.SchemaConverters.SchemaType -import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.{assertEquals, assertNotEquals} import org.junit.jupiter.api.Test class TestAvroSerDe extends SparkAdapterSupport { @Test def testAvroUnionSerDe(): Unit = { - val originalAvroRecord = { + val originalAvroRecord1 = { val minValue = new GenericData.Record(IntWrapper.SCHEMA$) minValue.put("value", 9) val maxValue = new GenericData.Record(IntWrapper.SCHEMA$) @@ -47,15 +47,25 @@ class TestAvroSerDe extends SparkAdapterSupport { record } + val originalAvroRecord2 = GenericData.get.deepCopy(originalAvroRecord1.getSchema, originalAvroRecord1) + originalAvroRecord2.put("totalUncompressedSize", 55L) + assertNotEquals(originalAvroRecord1, originalAvroRecord2) + val avroSchema = HoodieMetadataColumnStats.SCHEMA$ val SchemaType(catalystSchema, _) = SchemaConverters.toSqlType(avroSchema) val deserializer = sparkAdapter.createAvroDeserializer(avroSchema, catalystSchema) val serializer = sparkAdapter.createAvroSerializer(catalystSchema, avroSchema, nullable = false) - val row = deserializer.deserialize(originalAvroRecord).get - val deserializedAvroRecord = serializer.serialize(row) + val row1 = deserializer.deserialize(originalAvroRecord1).get + val row2 = deserializer.deserialize(originalAvroRecord2).get + assertNotEquals(row1, row2) + + val deserializedAvroRecord1 = serializer.serialize(row1) + val deserializedAvroRecord2 = serializer.serialize(row2) - assertEquals(originalAvroRecord, deserializedAvroRecord) + assertNotEquals(originalAvroRecord1, deserializedAvroRecord2) + assertEquals(originalAvroRecord1, deserializedAvroRecord1) + assertEquals(originalAvroRecord2, deserializedAvroRecord2) } }