Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ private static VariantValue truncateUpperBound(VariantValue value) {
}
}

private static class ParquetSchemaProducer extends VariantVisitor<Type> {
static class ParquetSchemaProducer extends VariantVisitor<Type> {
@Override
public Type object(VariantObject object, List<String> names, List<Type> typedValues) {
if (object.numFields() < 1) {
Expand Down Expand Up @@ -492,7 +492,7 @@ public Type primitive(VariantPrimitive<?> primitive) {
throw new UnsupportedOperationException("Unsupported shredding type: " + primitive.type());
}

private static GroupType objectFields(List<GroupType> fields) {
static GroupType objectFields(List<GroupType> fields) {
Types.GroupBuilder<GroupType> builder = Types.buildGroup(Type.Repetition.OPTIONAL);
for (GroupType field : fields) {
checkField(field);
Expand All @@ -502,14 +502,14 @@ private static GroupType objectFields(List<GroupType> fields) {
return builder.named("typed_value");
}

private static void checkField(GroupType fieldType) {
/**
 * Checks that a field group is declared with REQUIRED repetition.
 *
 * <p>The precondition fails with a message containing the actual repetition when the group is
 * not REQUIRED.
 *
 * @param fieldType the group type to validate
 */
static void checkField(GroupType fieldType) {
  boolean isRequired = fieldType.isRepetition(Type.Repetition.REQUIRED);
  Preconditions.checkArgument(
      isRequired, "Invalid field type repetition: %s should be REQUIRED", fieldType.getRepetition());
}

private static GroupType field(String name, Type shreddedType) {
static GroupType field(String name, Type shreddedType) {
Types.GroupBuilder<GroupType> builder =
Types.buildGroup(Type.Repetition.REQUIRED)
.optional(PrimitiveType.PrimitiveTypeName.BINARY)
Expand All @@ -523,7 +523,7 @@ private static GroupType field(String name, Type shreddedType) {
return builder.named(name);
}

private static void checkShreddedType(Type shreddedType) {
static void checkShreddedType(Type shreddedType) {
Preconditions.checkArgument(
shreddedType.getName().equals("typed_value"),
"Invalid shredded type name: %s should be typed_value",
Expand All @@ -534,16 +534,16 @@ private static void checkShreddedType(Type shreddedType) {
shreddedType.getRepetition());
}

private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) {
/**
 * Builds an optional Parquet primitive named {@code typed_value}.
 *
 * @param primitive the Parquet physical type of the column
 * @return the optional primitive type
 */
static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) {
  Types.PrimitiveBuilder<PrimitiveType> typedValue = Types.optional(primitive);
  return typedValue.named("typed_value");
}

private static Type shreddedPrimitive(
/**
 * Builds an optional, logically annotated Parquet primitive named {@code typed_value}.
 *
 * @param primitive the Parquet physical type of the column
 * @param annotation the logical type annotation applied to the column
 * @return the optional annotated primitive type
 */
static Type shreddedPrimitive(
    PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) {
  Types.PrimitiveBuilder<PrimitiveType> typedValue = Types.optional(primitive).as(annotation);
  return typedValue.named("typed_value");
}

private static Type shreddedPrimitive(
/**
 * Builds an optional, logically annotated Parquet primitive named {@code typed_value} with an
 * explicit type length.
 *
 * @param primitive the Parquet physical type of the column
 * @param annotation the logical type annotation applied to the column
 * @param length the type length passed to the Parquet builder
 * @return the optional annotated primitive type
 */
static Type shreddedPrimitive(
    PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation, int length) {
  Types.PrimitiveBuilder<PrimitiveType> typedValue =
      Types.optional(primitive).as(annotation).length(length);
  return typedValue.named("typed_value");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
*/
package org.apache.iceberg.parquet;

import static org.apache.iceberg.parquet.ParquetVariantUtil.ParquetSchemaProducer.checkField;
import static org.apache.iceberg.parquet.ParquetVariantUtil.ParquetSchemaProducer.field;
import static org.apache.iceberg.parquet.ParquetVariantUtil.ParquetSchemaProducer.shreddedPrimitive;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

Expand Down Expand Up @@ -64,7 +67,6 @@
import org.apache.parquet.hadoop.api.WriteSupport;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
Expand Down Expand Up @@ -202,7 +204,7 @@ public void testShreddedVariantPrimitives(VariantPrimitive<?> primitive) throws
.as("Null is not a shredded type")
.isTrue();

GroupType variantType = variant("var", 2, shreddedType(primitive));
GroupType variantType = variant("var", 2, ParquetVariantUtil.toParquetSchema(primitive));
MessageType parquetSchema = parquetSchema(variantType);

GenericRecord variant =
Expand Down Expand Up @@ -1454,73 +1456,6 @@ private static void checkShreddedType(Type shreddedType) {
shreddedType.getRepetition());
}

/**
 * Creates an optional Parquet primitive named {@code typed_value}.
 *
 * @param primitive the Parquet physical type to use
 * @return the optional primitive type
 */
private static Type shreddedPrimitive(PrimitiveTypeName primitive) {
  Type typedValue = Types.optional(primitive).named("typed_value");
  return typedValue;
}

/**
 * Creates an optional Parquet primitive named {@code typed_value} that carries a logical type
 * annotation.
 *
 * @param primitive the Parquet physical type to use
 * @param annotation the logical type annotation to attach
 * @return the optional annotated primitive type
 */
private static Type shreddedPrimitive(
    PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) {
  Type typedValue = Types.optional(primitive).as(annotation).named("typed_value");
  return typedValue;
}

/**
 * Returns the optional {@code typed_value} Parquet primitive used to shred the given variant
 * value, chosen from the value's variant type.
 *
 * <p>Decimal types take their scale from the value itself and use a fixed precision of 9, 18, or
 * 38 for DECIMAL4, DECIMAL8, and DECIMAL16 respectively.
 *
 * @param value the variant value whose type selects the Parquet type
 * @return the Parquet type for the shredded value
 * @throws UnsupportedOperationException if the variant type has no case in the switch
 */
private static Type shreddedType(VariantValue value) {
  switch (value.type()) {
    case BOOLEAN_TRUE:
    case BOOLEAN_FALSE:
      return shreddedPrimitive(PrimitiveTypeName.BOOLEAN);
    case INT8:
      return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8));
    case INT16:
      return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16));
    case INT32:
      return shreddedPrimitive(PrimitiveTypeName.INT32);
    case INT64:
      return shreddedPrimitive(PrimitiveTypeName.INT64);
    case FLOAT:
      return shreddedPrimitive(PrimitiveTypeName.FLOAT);
    case DOUBLE:
      return shreddedPrimitive(PrimitiveTypeName.DOUBLE);
    case DECIMAL4:
      BigDecimal decimal4 = (BigDecimal) value.asPrimitive().get();
      return shreddedPrimitive(
          PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(decimal4.scale(), 9));
    case DECIMAL8:
      BigDecimal decimal8 = (BigDecimal) value.asPrimitive().get();
      return shreddedPrimitive(
          PrimitiveTypeName.INT64, LogicalTypeAnnotation.decimalType(decimal8.scale(), 18));
    case DECIMAL16:
      BigDecimal decimal16 = (BigDecimal) value.asPrimitive().get();
      return shreddedPrimitive(
          PrimitiveTypeName.BINARY, LogicalTypeAnnotation.decimalType(decimal16.scale(), 38));
    case DATE:
      return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.dateType());
    case TIMESTAMPTZ:
      return shreddedPrimitive(
          PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS));
    case TIMESTAMPNTZ:
      return shreddedPrimitive(
          PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS));
    case BINARY:
      return shreddedPrimitive(PrimitiveTypeName.BINARY);
    case STRING:
      return shreddedPrimitive(PrimitiveTypeName.BINARY, STRING);
    case TIME:
      return shreddedPrimitive(
          PrimitiveTypeName.INT64, LogicalTypeAnnotation.timeType(false, TimeUnit.MICROS));
    case TIMESTAMPTZ_NANOS:
      return shreddedPrimitive(
          PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(true, TimeUnit.NANOS));
    case TIMESTAMPNTZ_NANOS:
      return shreddedPrimitive(
          PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(false, TimeUnit.NANOS));
    case UUID:
      // A FIXED_LEN_BYTE_ARRAY must declare its length, and the UUID annotation requires 16
      // bytes; building it without a length is rejected by the Parquet type builder.
      return Types.optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
          .length(16)
          .as(LogicalTypeAnnotation.uuidType())
          .named("typed_value");
  }

  throw new UnsupportedOperationException("Unsupported shredding type: " + value.type());
}

private static Object toAvroValue(VariantPrimitive<?> variant) {
switch (variant.type()) {
case DECIMAL4:
Expand All @@ -1546,13 +1481,6 @@ private static GroupType variant(String name, int fieldId, Type shreddedType) {
.named(name);
}

/**
 * Validates that a field group has REQUIRED repetition.
 *
 * <p>When the group is not REQUIRED, the precondition fails and the message reports the group's
 * actual repetition.
 *
 * @param fieldType the group type to validate
 */
private static void checkField(GroupType fieldType) {
  boolean isRequired = fieldType.isRepetition(Type.Repetition.REQUIRED);
  Preconditions.checkArgument(
      isRequired, "Invalid field type repetition: %s should be REQUIRED", fieldType.getRepetition());
}

private static GroupType objectFields(GroupType... fields) {
for (GroupType fieldType : fields) {
checkField(fieldType);
Expand All @@ -1561,15 +1489,6 @@ private static GroupType objectFields(GroupType... fields) {
return Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fields).named("typed_value");
}

/**
 * Builds a REQUIRED group for a shredded field: an optional BINARY column named {@code value}
 * followed by the supplied shredded type.
 *
 * <p>The shredded type is validated with {@code checkShreddedType} before the group is built.
 *
 * @param name the name of the resulting group
 * @param shreddedType the shredded type added after the {@code value} column
 * @return the field group
 */
private static GroupType field(String name, Type shreddedType) {
  checkShreddedType(shreddedType);

  Types.GroupBuilder<GroupType> fieldGroup =
      Types.buildGroup(Type.Repetition.REQUIRED)
          .optional(PrimitiveTypeName.BINARY)
          .named("value")
          .addField(shreddedType);

  return fieldGroup.named(name);
}

/**
 * Builds the field group named {@code element} for the given shredded type by delegating to
 * {@code field}.
 *
 * @param shreddedType the shredded type of the element
 * @return the {@code element} field group
 */
private static GroupType element(Type shreddedType) {
  GroupType elementGroup = field("element", shreddedType);
  return elementGroup;
}
Expand Down