Skip to content

Commit c6afc1b

Browse files
Push responsibility up to client code for deciding what is a valid Avro enum
1 parent f4134d1 commit c6afc1b

File tree

5 files changed

+94
-160
lines changed

5 files changed

+94
-160
lines changed

adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.arrow.adapter.avro.producers.AvroBigIntProducer;
2525
import org.apache.arrow.adapter.avro.producers.AvroBooleanProducer;
2626
import org.apache.arrow.adapter.avro.producers.AvroBytesProducer;
27+
import org.apache.arrow.adapter.avro.producers.AvroEnumProducer;
2728
import org.apache.arrow.adapter.avro.producers.AvroFixedSizeBinaryProducer;
2829
import org.apache.arrow.adapter.avro.producers.AvroFixedSizeListProducer;
2930
import org.apache.arrow.adapter.avro.producers.AvroFloat2Producer;
@@ -62,6 +63,7 @@
6263
import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecProducer;
6364
import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecTzProducer;
6465
import org.apache.arrow.util.Preconditions;
66+
import org.apache.arrow.vector.BaseIntVector;
6567
import org.apache.arrow.vector.BigIntVector;
6668
import org.apache.arrow.vector.BitVector;
6769
import org.apache.arrow.vector.DateDayVector;
@@ -100,7 +102,6 @@
100102
import org.apache.arrow.vector.complex.MapVector;
101103
import org.apache.arrow.vector.complex.StructVector;
102104
import org.apache.arrow.vector.dictionary.Dictionary;
103-
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
104105
import org.apache.arrow.vector.dictionary.DictionaryProvider;
105106
import org.apache.arrow.vector.types.FloatingPointPrecision;
106107
import org.apache.arrow.vector.types.TimeUnit;
@@ -331,12 +332,8 @@ private static <T> T buildBaseTypeSchema(
331332
String[] symbols = dictionarySymbols(dictionary);
332333
return builder.enumeration(field.getName()).symbols(symbols);
333334
} else {
334-
Field decodedField =
335-
new Field(
336-
field.getName(),
337-
dictionary.getVector().getField().getFieldType(),
338-
dictionary.getVector().getField().getChildren());
339-
return buildBaseTypeSchema(builder, decodedField, namespace, dictionaries);
335+
throw new IllegalArgumentException(
336+
"Dictionary-encoded field is not a valid enum: " + field.getName());
340337
}
341338
}
342339

@@ -449,7 +446,8 @@ private static LogicalType timestampLogicalType(ArrowType.Timestamp timestampTyp
449446

450447
private static boolean dictionaryIsValidEnum(Dictionary dictionary) {
451448

452-
if (dictionary.getVectorType().getTypeID() != ArrowType.ArrowTypeID.Utf8) {
449+
if (dictionary.getVectorType().getTypeID() != ArrowType.ArrowTypeID.Utf8
450+
|| !(dictionary.getVector() instanceof VarCharVector)) {
453451
return false;
454452
}
455453

@@ -502,33 +500,24 @@ private static String[] dictionarySymbols(Dictionary dictionary) {
502500
* @param vectors The vectors that will be used to produce Avro data
503501
* @return The resulting composite Avro producer
504502
*/
505-
public static CompositeAvroProducer createCompositeProducer(
506-
List<FieldVector> vectors, DictionaryProvider dictionaries) {
503+
public static CompositeAvroProducer createCompositeProducer(List<FieldVector> vectors) {
507504

508505
List<Producer<? extends FieldVector>> producers = new ArrayList<>(vectors.size());
509506

510507
for (FieldVector vector : vectors) {
511-
BaseAvroProducer<? extends FieldVector> producer = createProducer(vector, dictionaries);
508+
BaseAvroProducer<? extends FieldVector> producer = createProducer(vector);
512509
producers.add(producer);
513510
}
514511

515512
return new CompositeAvroProducer(producers);
516513
}
517514

518-
/** Overload provided for convenience, sets dictionaries = null. */
519-
public static CompositeAvroProducer createCompositeProducer(List<FieldVector> vectors) {
520-
521-
return createCompositeProducer(vectors, null);
522-
}
523-
524-
private static BaseAvroProducer<?> createProducer(
525-
FieldVector vector, DictionaryProvider dictionaries) {
515+
private static BaseAvroProducer<?> createProducer(FieldVector vector) {
526516
boolean nullable = vector.getField().isNullable();
527-
return createProducer(vector, nullable, dictionaries);
517+
return createProducer(vector, nullable);
528518
}
529519

530-
private static BaseAvroProducer<?> createProducer(
531-
FieldVector vector, boolean nullable, DictionaryProvider dictionaries) {
520+
private static BaseAvroProducer<?> createProducer(FieldVector vector, boolean nullable) {
532521

533522
Preconditions.checkNotNull(vector, "Arrow vector object can't be null");
534523

@@ -537,30 +526,13 @@ private static BaseAvroProducer<?> createProducer(
537526
// Avro understands nullable types as a union of type | null
538527
// Most nullable fields in a VSR will not be unions, so provide a special wrapper
539528
if (nullable && minorType != Types.MinorType.UNION) {
540-
final BaseAvroProducer<?> innerProducer = createProducer(vector, false, dictionaries);
529+
final BaseAvroProducer<?> innerProducer = createProducer(vector, false);
541530
return new AvroNullableProducer<>(innerProducer);
542531
}
543532

544533
if (vector.getField().getDictionary() != null) {
545-
if (dictionaries == null) {
546-
throw new IllegalArgumentException(
547-
"Field references a dictionary but no dictionaries were provided: "
548-
+ vector.getField().getName());
549-
}
550-
Dictionary dictionary = dictionaries.lookup(vector.getField().getDictionary().getId());
551-
if (dictionary == null) {
552-
throw new IllegalArgumentException(
553-
"Field references a dictionary that does not exist: "
554-
+ vector.getField().getName()
555-
+ ", dictionary ID = "
556-
+ vector.getField().getDictionary().getId());
557-
}
558-
// If a field is dictionary-encoded but cannot be represented as an Avro enum,
559-
// then decode it before writing
560-
if (!dictionaryIsValidEnum(dictionary)) {
561-
FieldVector decodedVector = (FieldVector) DictionaryEncoder.decode(vector, dictionary);
562-
return createProducer(decodedVector, nullable, dictionaries);
563-
}
534+
BaseIntVector dictEncodedVector = (BaseIntVector) vector;
535+
return new AvroEnumProducer(dictEncodedVector);
564536
}
565537

566538
switch (minorType) {
@@ -640,23 +612,21 @@ private static BaseAvroProducer<?> createProducer(
640612
Producer<?>[] childProducers = new Producer<?>[childVectors.size()];
641613
for (int i = 0; i < childVectors.size(); i++) {
642614
FieldVector childVector = childVectors.get(i);
643-
childProducers[i] =
644-
createProducer(childVector, childVector.getField().isNullable(), dictionaries);
615+
childProducers[i] = createProducer(childVector, childVector.getField().isNullable());
645616
}
646617
return new AvroStructProducer(structVector, childProducers);
647618

648619
case LIST:
649620
ListVector listVector = (ListVector) vector;
650621
FieldVector itemVector = listVector.getDataVector();
651-
Producer<?> itemProducer =
652-
createProducer(itemVector, itemVector.getField().isNullable(), dictionaries);
622+
Producer<?> itemProducer = createProducer(itemVector, itemVector.getField().isNullable());
653623
return new AvroListProducer(listVector, itemProducer);
654624

655625
case FIXED_SIZE_LIST:
656626
FixedSizeListVector fixedListVector = (FixedSizeListVector) vector;
657627
FieldVector fixedItemVector = fixedListVector.getDataVector();
658628
Producer<?> fixedItemProducer =
659-
createProducer(fixedItemVector, fixedItemVector.getField().isNullable(), dictionaries);
629+
createProducer(fixedItemVector, fixedItemVector.getField().isNullable());
660630
return new AvroFixedSizeListProducer(fixedListVector, fixedItemProducer);
661631

662632
case MAP:
@@ -670,7 +640,7 @@ private static BaseAvroProducer<?> createProducer(
670640
FieldVector valueVector = entryVector.getChildrenFromFields().get(1);
671641
Producer<?> keyProducer = new AvroStringProducer(keyVector);
672642
Producer<?> valueProducer =
673-
createProducer(valueVector, valueVector.getField().isNullable(), dictionaries);
643+
createProducer(valueVector, valueVector.getField().isNullable());
674644
Producer<?> entryProducer =
675645
new AvroStructProducer(entryVector, new Producer<?>[] {keyProducer, valueProducer});
676646
return new AvroMapProducer(mapVector, entryProducer);

adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@
1717
package org.apache.arrow.adapter.avro.producers;
1818

1919
import java.io.IOException;
20-
import org.apache.arrow.vector.IntVector;
20+
import org.apache.arrow.vector.BaseIntVector;
2121
import org.apache.avro.io.Encoder;
2222

2323
/**
24-
* Producer that produces enum values from a dictionary-encoded {@link IntVector}, writes data to an
25-
* Avro encoder.
24+
* Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data
25+
* to an Avro encoder.
2626
*/
27-
public class AvroEnumProducer extends BaseAvroProducer<IntVector> {
27+
public class AvroEnumProducer extends BaseAvroProducer<BaseIntVector> {
2828

2929
/** Instantiate an AvroEnumProducer. */
30-
public AvroEnumProducer(IntVector vector) {
30+
public AvroEnumProducer(BaseIntVector vector) {
3131
super(vector);
3232
}
3333

3434
@Override
3535
public void produce(Encoder encoder) throws IOException {
36-
encoder.writeEnum(vector.get(currentIndex++));
36+
encoder.writeEnum((int) vector.getValueAsLong(currentIndex++));
3737
}
3838
}

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java

Lines changed: 1 addition & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -2872,8 +2872,7 @@ public void testWriteDictEnumEncoded() throws Exception {
28722872
// Write an AVRO block using the producer classes
28732873
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
28742874
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
2875-
CompositeAvroProducer producer =
2876-
ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
2875+
CompositeAvroProducer producer = ArrowToAvroUtils.createCompositeProducer(vectors);
28772876
for (int row = 0; row < rowCount; row++) {
28782877
producer.produce(encoder);
28792878
}
@@ -2898,107 +2897,4 @@ record = datumReader.read(record, decoder);
28982897
}
28992898
}
29002899
}
2901-
2902-
@Test
2903-
public void testWriteEnumDecoded() throws Exception {
2904-
2905-
// Dict encoded fields that are not valid Avro enums should be decoded on write
2906-
2907-
BufferAllocator allocator = new RootAllocator();
2908-
2909-
// Create a dictionary
2910-
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
2911-
VarCharVector dictionaryVector =
2912-
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
2913-
2914-
dictionaryVector.allocateNew(3);
2915-
dictionaryVector.set(0, "passion fruit".getBytes()); // spaced not allowed
2916-
dictionaryVector.set(1, "banana".getBytes());
2917-
dictionaryVector.set(2, "cherry".getBytes());
2918-
dictionaryVector.setValueCount(3);
2919-
2920-
Dictionary dictionary =
2921-
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
2922-
2923-
FieldType dictionaryField2 = new FieldType(false, new ArrowType.Int(64, true), null);
2924-
BigIntVector dictionaryVector2 =
2925-
new BigIntVector(new Field("dictionary2", dictionaryField2, null), allocator);
2926-
2927-
dictionaryVector2.allocateNew(3);
2928-
dictionaryVector2.set(0, 0L);
2929-
dictionaryVector2.set(1, 1L);
2930-
dictionaryVector2.set(2, 2L);
2931-
dictionaryVector2.setValueCount(3);
2932-
2933-
Dictionary dictionary2 =
2934-
new Dictionary(dictionaryVector2, new DictionaryEncoding(2L, false, null));
2935-
2936-
DictionaryProvider dictionaries =
2937-
new DictionaryProvider.MapDictionaryProvider(dictionary, dictionary2);
2938-
2939-
// Field definition
2940-
FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
2941-
VarCharVector stringVector =
2942-
new VarCharVector(new Field("enumField", stringField, null), allocator);
2943-
stringVector.allocateNew(10);
2944-
stringVector.setSafe(0, "passion fruit".getBytes());
2945-
stringVector.setSafe(1, "banana".getBytes());
2946-
stringVector.setSafe(2, "cherry".getBytes());
2947-
stringVector.setSafe(3, "cherry".getBytes());
2948-
stringVector.setSafe(4, "passion fruit".getBytes());
2949-
stringVector.setSafe(5, "banana".getBytes());
2950-
stringVector.setSafe(6, "passion fruit".getBytes());
2951-
stringVector.setSafe(7, "cherry".getBytes());
2952-
stringVector.setSafe(8, "banana".getBytes());
2953-
stringVector.setSafe(9, "passion fruit".getBytes());
2954-
stringVector.setValueCount(10);
2955-
2956-
FieldType longField = new FieldType(false, new ArrowType.Int(64, true), null);
2957-
BigIntVector longVector = new BigIntVector(new Field("enumField2", longField, null), allocator);
2958-
longVector.allocateNew(10);
2959-
for (int i = 0; i < 10; i++) {
2960-
longVector.setSafe(i, (long) i % 3);
2961-
}
2962-
longVector.setValueCount(10);
2963-
2964-
IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary);
2965-
IntVector encodedVector2 = (IntVector) DictionaryEncoder.encode(longVector, dictionary2);
2966-
2967-
// Set up VSR
2968-
List<FieldVector> vectors = Arrays.asList(encodedVector, encodedVector2);
2969-
int rowCount = 10;
2970-
2971-
try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
2972-
2973-
File dataFile = new File(TMP, "testWriteEnumDecodedavro");
2974-
2975-
// Write an AVRO block using the producer classes
2976-
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
2977-
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
2978-
CompositeAvroProducer producer =
2979-
ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
2980-
for (int row = 0; row < rowCount; row++) {
2981-
producer.produce(encoder);
2982-
}
2983-
encoder.flush();
2984-
}
2985-
2986-
// Set up reading the AVRO block as a GenericRecord
2987-
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
2988-
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema);
2989-
2990-
try (InputStream inputStream = new FileInputStream(dataFile)) {
2991-
2992-
BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
2993-
GenericRecord record = null;
2994-
2995-
// Read and check values
2996-
for (int row = 0; row < rowCount; row++) {
2997-
record = datumReader.read(record, decoder);
2998-
assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString());
2999-
assertEquals(longVector.getObject(row), record.get("enumField2"));
3000-
}
3001-
}
3002-
}
3003-
}
30042900
}

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
package org.apache.arrow.adapter.avro;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import java.util.Arrays;
2223
import java.util.List;
2324
import org.apache.arrow.memory.BufferAllocator;
2425
import org.apache.arrow.memory.RootAllocator;
26+
import org.apache.arrow.vector.BigIntVector;
2527
import org.apache.arrow.vector.VarCharVector;
2628
import org.apache.arrow.vector.dictionary.Dictionary;
2729
import org.apache.arrow.vector.dictionary.DictionaryProvider;
@@ -1437,4 +1439,70 @@ public void testWriteDictEnumEncoded() {
14371439
assertEquals("banana", enumField.schema().getEnumSymbols().get(1));
14381440
assertEquals("cherry", enumField.schema().getEnumSymbols().get(2));
14391441
}
1442+
1443+
@Test
1444+
public void testWriteDictEnumInvalid() {
1445+
1446+
BufferAllocator allocator = new RootAllocator();
1447+
1448+
// Create a dictionary
1449+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
1450+
VarCharVector dictionaryVector =
1451+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
1452+
1453+
dictionaryVector.allocateNew(3);
1454+
dictionaryVector.set(0, "passion fruit".getBytes());
1455+
dictionaryVector.set(1, "banana".getBytes());
1456+
dictionaryVector.set(2, "cherry".getBytes());
1457+
dictionaryVector.setValueCount(3);
1458+
1459+
Dictionary dictionary =
1460+
new Dictionary(
1461+
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
1462+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
1463+
1464+
List<Field> fields =
1465+
Arrays.asList(
1466+
new Field(
1467+
"enumField",
1468+
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
1469+
null));
1470+
1471+
assertThrows(
1472+
IllegalArgumentException.class,
1473+
() -> ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries));
1474+
}
1475+
1476+
@Test
1477+
public void testWriteDictEnumInvalid2() {
1478+
1479+
BufferAllocator allocator = new RootAllocator();
1480+
1481+
// Create a dictionary
1482+
FieldType dictionaryField = new FieldType(false, new ArrowType.Int(64, true), null);
1483+
BigIntVector dictionaryVector =
1484+
new BigIntVector(new Field("dictionary", dictionaryField, null), allocator);
1485+
1486+
dictionaryVector.allocateNew(3);
1487+
dictionaryVector.set(0, 123L);
1488+
dictionaryVector.set(1, 456L);
1489+
dictionaryVector.set(2, 789L);
1490+
dictionaryVector.setValueCount(3);
1491+
1492+
Dictionary dictionary =
1493+
new Dictionary(
1494+
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
1495+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
1496+
1497+
List<Field> fields =
1498+
Arrays.asList(
1499+
new Field(
1500+
"enumField",
1501+
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
1502+
null));
1503+
1504+
assertThrows(
1505+
IllegalArgumentException.class,
1506+
() -> ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries));
1507+
}
14401508
}

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ private static void roundTripTest(
120120
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
121121
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
122122
CompositeAvroProducer producer =
123-
ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors(), dictionaries);
123+
ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors());
124124
for (int row = 0; row < rowCount; row++) {
125125
producer.produce(encoder);
126126
}

0 commit comments

Comments
 (0)