Merged
Changes from all commits
@@ -21,6 +21,8 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.avro.generic.GenericData;
import org.apache.iceberg.Files;
@@ -31,17 +33,17 @@
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.unsafe.types.UTF8String;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import static org.apache.avro.Schema.Type.INT;
import static org.apache.avro.Schema.Type.NULL;
import static org.apache.iceberg.spark.SparkSchemaUtil.convert;
import static org.apache.iceberg.spark.data.TestHelpers.assertEquals;

public class TestSparkAvroReaderForFieldsWithDefaultValue {

@@ -50,72 +52,150 @@ public class TestSparkAvroReaderForFieldsWithDefaultValue {

@Test
public void testAvroDefaultValues() throws IOException {
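// the initial write schema contains only a single required string field, f0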
String writeSchemaString = "{\n" +
" \"namespace\": \"com.n1\",\n" +
" \"type\": \"record\",\n" +
" \"name\": \"n1\",\n" +
" \"fields\": [\n" +
" {\n" +
" \"name\": \"f0\",\n" +
" \"type\": \"string\"\n" +
" }\n" +
" ]\n" +
"}";

org.apache.avro.Schema writeSchema = new org.apache.avro.Schema.Parser().parse(writeSchemaString);
org.apache.iceberg.Schema icebergWriteSchema = AvroSchemaUtil.toIceberg(writeSchema);

List<GenericData.Record> expected = RandomData.generateList(icebergWriteSchema, 2, 0L);

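// temp.newFile() creates the file up front; delete it so the Avro writer can create it fresh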
File testFile = temp.newFile();
Assert.assertTrue("Delete should succeed", testFile.delete());

// write records with initial writeSchema
try (FileAppender<GenericData.Record> writer = Avro.write(Files.localOutput(testFile))
.schema(icebergWriteSchema)
.named("test")
.build()) {
for (GenericData.Record rec : expected) {
writer.add(rec);
}
}

// evolve the schema by adding new fields with default values
String evolvedSchemaString = "{\n" +
" \"namespace\": \"com.n1\",\n" +
" \"type\": \"record\",\n" +
" \"name\": \"n1\",\n" +
" \"fields\": [\n" +
" {\n" +
" \"name\": \"f0\",\n" +
" \"type\": \"string\"\n" +
" },\n" +
" {\n" +
" \"name\": \"f1\",\n" +
" \"type\": \"string\",\n" +
" \"default\": \"foo\"\n" +
" },\n" +
" {\n" +
" \"name\": \"f2\",\n" +
" \"type\": \"int\",\n" +
" \"default\": 1\n" +
" },\n" +
" {\n" +
" \"name\": \"f3\",\n" +
" \"type\": {\n" +
" \"type\": \"map\",\n" +
" \"values\" : \"int\"\n" +
" },\n" +
" \"default\": {\"a\": 1}\n" +
" },\n" +
" {\n" +
" \"name\": \"f4\",\n" +
" \"type\": {\n" +
" \"type\": \"array\",\n" +
" \"items\" : \"int\"\n" +
" },\n" +
" \"default\": [1, 2, 3]\n" +
" },\n" +
" {\n" +
" \"name\": \"f5\",\n" +
" \"type\": {\n" +
" \"type\": \"record\",\n" +
" \"name\": \"F5\",\n" +
" \"fields\" : [\n" +
" {\"name\": \"ff1\", \"type\": \"long\"},\n" +
" {\"name\": \"ff2\", \"type\": \"string\"}\n" +
" ]\n" +
" },\n" +
" \"default\": {\n" +
" \"ff1\": 999,\n" +
" \"ff2\": \"foo\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"name\": \"f6\",\n" +
" \"type\": {\n" +
" \"type\": \"map\",\n" +
" \"values\": {\n" +
" \"type\": \"array\",\n" +
" \"items\" : \"int\"\n" +
" }\n" +
" },\n" +
" \"default\": {\"key\": [1, 2, 3]}\n" +
" },\n" +
" {\n" +
" \"name\": \"f7\",\n" +
" \"type\": {\n" +
" \"type\": \"fixed\",\n" +
" \"name\": \"md5\",\n" +
" \"size\": 2\n" +
" },\n" +
" \"default\": \"FF\"\n" +
" }\n" +
" ]\n" +
"}";
org.apache.avro.Schema evolvedSchema = new org.apache.avro.Schema.Parser().parse(evolvedSchemaString);
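// the evolved schema adds fields with defaults: f1 (string), f2 (int), f3 (map<string,int>),
// f4 (array<int>), f5 (nested record), f6 (map<string,array<int>>) and f7 (fixed of size 2)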

// read written rows with evolved schema
List<InternalRow> rows;
Schema icebergReadSchema = AvroSchemaUtil.toIceberg(evolvedSchema);
try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
.createReaderFunc(SparkAvroReader::new)
.project(icebergReadSchema)
.build()) {
rows = Lists.newArrayList(reader);
}

// validate all rows, and all fields are read properly
Assert.assertNotNull(rows);
Assert.assertEquals(expected.size(), rows.size());
for (int row = 0; row < expected.size(); row++) {
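// f0 should round-trip from the written data; f1..f7 should be filled from the evolved schema's
// defaults, converted to Spark's internal types (UTF8String, ArrayBasedMapData, GenericArrayData,
// nested InternalRow)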
InternalRow actualRow = rows.get(row);

final InternalRow expectedRow = new GenericInternalRow(8);
expectedRow.update(0, UTF8String.fromString((String) expected.get(row).get(0)));
expectedRow.update(1, UTF8String.fromString("foo"));
expectedRow.update(2, 1);
expectedRow.update(3, new ArrayBasedMapData(
new GenericArrayData(Arrays.asList(UTF8String.fromString("a"))),
new GenericArrayData(Arrays.asList(1))));
expectedRow.update(4, new GenericArrayData(ImmutableList.of(1, 2, 3).toArray()));

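// the record default {"ff1": 999, "ff2": "foo"} is expected as a nested InternalRow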
final InternalRow nestedStructData = new GenericInternalRow(2);
nestedStructData.update(0, 999L);
nestedStructData.update(1, UTF8String.fromString("foo"));
expectedRow.update(5, nestedStructData);

List<GenericArrayData> listOfLists = new ArrayList<GenericArrayData>(1);
listOfLists.add(new GenericArrayData(ImmutableList.of(1, 2, 3).toArray()));
expectedRow.update(6, new ArrayBasedMapData(
new GenericArrayData(Arrays.asList(UTF8String.fromString("key"))),
new GenericArrayData(listOfLists.toArray())));

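// the fixed(2) default "FF" is expected as its two raw bytes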
byte[] objGUIDByteArr = "FF".getBytes("UTF-8");
expectedRow.update(7, objGUIDByteArr);
assertEquals(icebergReadSchema, actualRow, expectedRow);

}
}
}