Merged
Changes from 1 commit
10 changes: 2 additions & 8 deletions core/src/main/java/org/apache/iceberg/avro/ValueReaders.java
@@ -40,7 +40,6 @@
import org.apache.avro.io.ResolvingDecoder;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.common.DynConstructors;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;

import static java.util.Collections.emptyIterator;
@@ -580,10 +579,9 @@ protected StructReader(List<ValueReader<?>> readers, Types.StructType struct, Ma
List<Object> constantList = Lists.newArrayListWithCapacity(fields.size());
for (int pos = 0; pos < fields.size(); pos += 1) {
Types.NestedField field = fields.get(pos);
Object constant = idToConstant.get(field.fieldId());
if (constant != null) {
if (idToConstant.containsKey(field.fieldId())) {
positionList.add(pos);
constantList.add(prepareConstant(field.type(), constant));
constantList.add(idToConstant.get(field.fieldId()));
}
}

@@ -597,10 +595,6 @@ protected StructReader(List<ValueReader<?>> readers, Types.StructType struct, Ma

protected abstract void set(S struct, int pos, Object value);

protected Object prepareConstant(Type type, Object value) {
Contributor Author:
I'm moving this out of Avro and adding a callback for converting constants to PartitionUtil.constantsMap. That way, Spark can supply a conversion function and use it in both places, instead of duplicating the conversion in the Avro and Parquet readers. (A short sketch of the resulting call pattern follows this file's diff.)

return value;
}

public ValueReader<?> reader(int pos) {
return readers[pos];
}
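To make the call pattern described in the comment above concrete, here is a minimal sketch of an engine-side caller supplying a converter. The class and method names on the caller side are hypothetical; only PartitionUtil.constantsMap and the BiFunction-shaped callback come from this change, and in this PR the real converter is Spark's RowDataReader::convertConstant, shown further down.

import java.util.Map;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.util.PartitionUtil;

class ConstantsMapSketch {
  // Hypothetical engine-specific conversion; Spark's real implementation is
  // RowDataReader::convertConstant later in this diff.
  static Object toEngineRepresentation(Type type, Object value) {
    return value;  // identity placeholder; a real engine converts to its internal types
  }

  static Map<Integer, ?> constantsFor(FileScanTask task) {
    // The converter is applied once here, so both the Avro and Parquet readers
    // receive already-converted constants keyed by field id.
    return PartitionUtil.constantsMap(task, ConstantsMapSketch::toEngineRepresentation);
  }
}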
16 changes: 13 additions & 3 deletions core/src/main/java/org/apache/iceberg/util/PartitionUtil.java
@@ -22,26 +22,36 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;

public class PartitionUtil {
private PartitionUtil() {
}

public static Map<Integer, ?> constantsMap(FileScanTask task) {
return constantsMap(task.spec(), task.file().partition());
return constantsMap(task, (type, constant) -> constant);
}

private static Map<Integer, ?> constantsMap(PartitionSpec spec, StructLike partitionData) {
public static Map<Integer, ?> constantsMap(FileScanTask task, BiFunction<Type, Object, Object> convertConstant) {
return constantsMap(task.spec(), task.file().partition(), convertConstant);
}

private static Map<Integer, ?> constantsMap(PartitionSpec spec, StructLike partitionData,
BiFunction<Type, Object, Object> convertConstant) {
// use java.util.HashMap because partition data may contain null values
Map<Integer, Object> idToConstant = new HashMap<>();
List<Types.NestedField> partitionFields = spec.partitionType().fields();
List<PartitionField> fields = spec.fields();
for (int pos = 0; pos < fields.size(); pos += 1) {
PartitionField field = fields.get(pos);
idToConstant.put(field.sourceId(), partitionData.get(pos, Object.class));
Object converted = convertConstant.apply(partitionFields.get(pos).type(), partitionData.get(pos, Object.class));
idToConstant.put(field.sourceId(), converted);
}
return idToConstant;
}
spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
@@ -20,6 +20,7 @@
package org.apache.iceberg.spark.data;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.math.BigDecimal;
@@ -66,23 +67,29 @@ public class SparkParquetReaders {
private SparkParquetReaders() {
}

@SuppressWarnings("unchecked")
public static ParquetValueReader<InternalRow> buildReader(Schema expectedSchema,
MessageType fileSchema) {
return buildReader(expectedSchema, fileSchema, ImmutableMap.of());
}

@SuppressWarnings("unchecked")
public static ParquetValueReader<InternalRow> buildReader(Schema expectedSchema,
MessageType fileSchema,
Map<Integer, ?> idToConstant) {
if (ParquetSchemaUtil.hasIds(fileSchema)) {
return (ParquetValueReader<InternalRow>)
TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
new ReadBuilder(fileSchema));
new ReadBuilder(fileSchema, idToConstant));
} else {
return (ParquetValueReader<InternalRow>)
TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
new FallbackReadBuilder(fileSchema));
new FallbackReadBuilder(fileSchema, idToConstant));
}
}

private static class FallbackReadBuilder extends ReadBuilder {
FallbackReadBuilder(MessageType type) {
super(type);
FallbackReadBuilder(MessageType type, Map<Integer, ?> idToConstant) {
super(type, idToConstant);
}

@Override
@@ -113,9 +120,11 @@ public ParquetValueReader<?> struct(Types.StructType ignored, GroupType struct,

private static class ReadBuilder extends TypeWithSchemaVisitor<ParquetValueReader<?>> {
private final MessageType type;
private final Map<Integer, ?> idToConstant;

ReadBuilder(MessageType type) {
ReadBuilder(MessageType type, Map<Integer, ?> idToConstant) {
this.type = type;
this.idToConstant = idToConstant;
}

@Override
@@ -146,13 +155,19 @@ public ParquetValueReader<?> struct(Types.StructType expected, GroupType struct,
List<Type> types = Lists.newArrayListWithExpectedSize(expectedFields.size());
for (Types.NestedField field : expectedFields) {
int id = field.fieldId();
ParquetValueReader<?> reader = readersById.get(id);
if (reader != null) {
reorderedFields.add(reader);
types.add(typesById.get(id));
} else {
reorderedFields.add(ParquetValueReaders.nulls());
if (idToConstant.containsKey(id)) {
// containsKey is used because the constant may be null
reorderedFields.add(ParquetValueReaders.constant(idToConstant.get(id)));
types.add(null);
} else {
ParquetValueReader<?> reader = readersById.get(id);
if (reader != null) {
reorderedFields.add(reader);
types.add(typesById.get(id));
} else {
reorderedFields.add(ParquetValueReaders.nulls());
types.add(null);
}
}
}

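For context on the constant branch above: ParquetValueReaders.constant(idToConstant.get(id)) presumably wraps the already-converted value in a reader that returns it for every row and binds to no file columns, which is why the matching entry in types is null. A rough, self-contained sketch of that idea (illustrative only, not Iceberg's actual ParquetValueReader interface):

// Illustrative only: a trimmed-down "constant reader", not the real Iceberg interface.
final class ConstantReaderSketch<T> {
  private final T constant;

  ConstantReaderSketch(T constant) {
    this.constant = constant;  // may be null, e.g. for a null partition value
  }

  // Every row gets the same value; no column data is consumed from the file.
  T read() {
    return constant;
  }
}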
Spark Avro value readers (prepareConstant override removed)
@@ -29,14 +29,11 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.avro.generic.GenericData;
import org.apache.avro.io.Decoder;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.avro.ValueReader;
import org.apache.iceberg.avro.ValueReaders;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
@@ -287,30 +284,5 @@ protected void set(InternalRow struct, int pos, Object value) {
struct.setNullAt(pos);
}
}

@Override
protected Object prepareConstant(Type type, Object value) {
Contributor Author:
Moved into Spark (now convertConstant in RowDataReader, below).

switch (type.typeId()) {
case DECIMAL:
return Decimal.apply((BigDecimal) value);
case STRING:
if (value instanceof Utf8) {
Utf8 utf8 = (Utf8) value;
return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
}
return UTF8String.fromString(value.toString());
case FIXED:
if (value instanceof byte[]) {
return value;
} else if (value instanceof GenericData.Fixed) {
return ((GenericData.Fixed) value).bytes();
}
return ByteBuffers.toByteArray((ByteBuffer) value);
case BINARY:
return ByteBuffers.toByteArray((ByteBuffer) value);
default:
}
return value;
}
}
}
RowDataReader.java
@@ -24,10 +24,14 @@
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.avro.generic.GenericData;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.DataTask;
@@ -47,20 +51,24 @@
import org.apache.iceberg.spark.data.SparkAvroReader;
import org.apache.iceberg.spark.data.SparkOrcReader;
import org.apache.iceberg.spark.data.SparkParquetReaders;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.PartitionUtil;
import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.JoinedRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
import scala.collection.JavaConverters;

class RowDataReader extends BaseDataReader<InternalRow> {
private static final Set<FileFormat> SUPPORTS_CONSTANTS = Sets.newHashSet(FileFormat.AVRO);
private static final Set<FileFormat> SUPPORTS_CONSTANTS = Sets.newHashSet(FileFormat.AVRO, FileFormat.PARQUET);
// for some reason, the apply method can't be called from Java without reflection
private static final DynMethods.UnboundMethod APPLY_PROJECTION = DynMethods.builder("apply")
.impl(UnsafeProjection.class, InternalRow.class)
@@ -103,7 +111,7 @@ Iterator<InternalRow> open(FileScanTask task) {
if (hasJoinedPartitionColumns) {
if (SUPPORTS_CONSTANTS.contains(file.format())) {
iterSchema = requiredSchema;
iter = open(task, requiredSchema, PartitionUtil.constantsMap(task));
iter = open(task, requiredSchema, PartitionUtil.constantsMap(task, RowDataReader::convertConstant));
} else {
// schema used to read data files
Schema readSchema = TypeUtil.selectNot(requiredSchema, idColumns);
@@ -144,7 +152,7 @@ private Iterator<InternalRow> open(FileScanTask task, Schema readSchema, Map<Int

switch (task.file().format()) {
case PARQUET:
iter = newParquetIterable(location, task, readSchema);
iter = newParquetIterable(location, task, readSchema, idToConstant);
break;

case AVRO:
@@ -182,11 +190,12 @@ private CloseableIterable<InternalRow> newAvroIterable(
private CloseableIterable<InternalRow> newParquetIterable(
InputFile location,
FileScanTask task,
Schema readSchema) {
Schema readSchema,
Map<Integer, ?> idToConstant) {
return Parquet.read(location)
.project(readSchema)
.split(task.start(), task.length())
.createReaderFunc(fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema))
.createReaderFunc(fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema, idToConstant))
.filter(task.residual())
.caseSensitive(caseSensitive)
.build();
@@ -233,4 +242,32 @@ private static UnsafeProjection projection(Schema finalSchema, Schema readSchema
JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(),
JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq());
}

private static Object convertConstant(Type type, Object value) {
if (value == null) {
return null;
}

switch (type.typeId()) {
case DECIMAL:
return Decimal.apply((BigDecimal) value);
case STRING:
if (value instanceof Utf8) {
Utf8 utf8 = (Utf8) value;
return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
}
return UTF8String.fromString(value.toString());
case FIXED:
if (value instanceof byte[]) {
return value;
} else if (value instanceof GenericData.Fixed) {
return ((GenericData.Fixed) value).bytes();
}
return ByteBuffers.toByteArray((ByteBuffer) value);
case BINARY:
return ByteBuffers.toByteArray((ByteBuffer) value);
default:
}
return value;
}
}
Partition values test (adds testNestedPartitionValues)
@@ -41,6 +41,7 @@
import org.apache.spark.sql.SparkSession;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
@@ -307,4 +308,72 @@ public void testPartitionValueTypes() throws Exception {
TestTables.clearTables();
}
}

@Test
public void testNestedPartitionValues() throws Exception {
Assume.assumeTrue("ORC can't project nested partition values", !format.equalsIgnoreCase("orc"));

String[] columnNames = new String[] {
"b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10"
};

HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
Schema nestedSchema = new Schema(optional(1, "nested", SUPPORTED_PRIMITIVES.asStruct()));

// create a table around the source data
String sourceLocation = temp.newFolder("source_table").toString();
Table source = tables.create(nestedSchema, sourceLocation);

// write out an Avro data file with all of the data types for source data
List<GenericData.Record> expected = RandomData.generateList(source.schema(), 2, 128735L);
File avroData = temp.newFile("data.avro");
Assert.assertTrue(avroData.delete());
try (FileAppender<GenericData.Record> appender = Avro.write(Files.localOutput(avroData))
.schema(source.schema())
.build()) {
appender.addAll(expected);
}

// add the Avro data file to the source table
Contributor:
Why not write the source data in the parameterized format the test is running with?

Contributor Author:
This is just source data for the Spark write, which uses the target format under test. Since the source data isn't part of what's being tested, we don't want it to change in ways that might affect the test.

source.newAppend()
.appendFile(DataFiles.fromInputFile(Files.localInput(avroData), 10))
.commit();

Dataset<Row> sourceDF = spark.read().format("iceberg").load(sourceLocation);

try {
for (String column : columnNames) {
String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString();

File parent = temp.newFolder(desc);
File location = new File(parent, "test");
File dataFolder = new File(location, "data");
Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs());

PartitionSpec spec = PartitionSpec.builderFor(nestedSchema).identity("nested." + column).build();

Table table = tables.create(nestedSchema, spec, location.toString());
table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit();

sourceDF.write()
.format("iceberg")
.mode("append")
.save(location.toString());

List<Row> actual = spark.read()
.format("iceberg")
.load(location.toString())
.collectAsList();

Assert.assertEquals("Number of rows should match", expected.size(), actual.size());

for (int i = 0; i < expected.size(); i += 1) {
TestHelpers.assertEqualsSafe(
nestedSchema.asStruct(), expected.get(i), actual.get(i));
}
}
} finally {
TestTables.clearTables();
}
}
}