spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
@@ -22,20 +22,37 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.common.DynMethods;
import org.apache.iceberg.encryption.EncryptedFiles;
import org.apache.iceberg.encryption.EncryptionManager;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.Pair;
import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConverters;

/**
* Base class of readers of type {@link InputPartitionReader} that read data as objects of type {@code T}.
@@ -44,6 +61,15 @@
*/
@SuppressWarnings("checkstyle:VisibilityModifier")
abstract class BaseDataReader<T> implements InputPartitionReader<T> {
// for some reason, the apply method can't be called from Java without reflection
static final DynMethods.UnboundMethod APPLY_PROJECTION = DynMethods.builder("apply")
.impl(UnsafeProjection.class, InternalRow.class)
.build();
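  // Example (hypothetical local variables; mirrors the call site in open() further down in this class):
  //   InternalRow projected = APPLY_PROJECTION.bind(projection(finalSchema, iterSchema)).invoke(row);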

final Schema tableSchema;
private final Schema expectedSchema;
final boolean caseSensitive;

private final Iterator<FileScanTask> tasks;
private final FileIO fileIo;
private final Map<String, InputFile> inputFiles;
@@ -52,7 +78,11 @@ abstract class BaseDataReader<T> implements InputPartitionReader<T> {
Closeable currentCloseable;
private T current = null;

BaseDataReader(CombinedScanTask task, FileIO fileIo, EncryptionManager encryptionManager) {
BaseDataReader(CombinedScanTask task, FileIO fileIo, EncryptionManager encryptionManager, Schema tableSchema,
Schema expectedSchema, boolean caseSensitive) {
this.tableSchema = tableSchema;
this.expectedSchema = expectedSchema;
this.caseSensitive = caseSensitive;
this.fileIo = fileIo;
this.tasks = task.files().iterator();
Iterable<InputFile> decryptedFiles = encryptionManager.decrypt(Iterables.transform(
@@ -88,7 +118,78 @@ public T get() {
return current;
}

abstract Iterator<T> open(FileScanTask task);
/**
* Return a {@link Pair} of {@link Schema} and {@link Iterator} over records of type T that include the identity
* partition columns being projected.
*/
  abstract Pair<Schema, Iterator<T>> getJoinedSchemaAndIteratorWithIdentityPartition(
      DataFile file, FileScanTask task,
      Schema requiredSchema, Set<Integer> idColumns, PartitionSpec spec);

  [Contributor review comment: Since this is part of the Base Readers api, would be good to add a docstring on what this is supposed to do and where it's being used.]
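  // Note: BaseDataReader.open(FileScanTask) below calls getJoinedSchemaAndIteratorWithIdentityPartition
  // only when the task's partition spec has identity partition fields. Implementations (see RowDataReader
  // later in this diff) return the schema of the rows they actually produce together with the iterator,
  // so open() can then project those rows to the expected schema.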

abstract Iterator<T> open(FileScanTask task, Schema readSchema, Map<Integer, ?> idToConstant);

private Iterator<T> open(FileScanTask task) {
DataFile file = task.file();

// update the current file for Spark's filename() function
InputFileBlockHolder.set(file.path().toString(), task.start(), task.length());

// schema or rows returned by readers
Schema finalSchema = expectedSchema;
PartitionSpec spec = task.spec();
Set<Integer> idColumns = spec.identitySourceIds();

// schema needed for the projection and filtering
StructType sparkType = SparkSchemaUtil.convert(finalSchema);
Schema requiredSchema = SparkSchemaUtil.prune(tableSchema, sparkType, task.residual(), caseSensitive);
boolean hasJoinedPartitionColumns = !idColumns.isEmpty();
boolean hasExtraFilterColumns = requiredSchema.columns().size() != finalSchema.columns().size();

Schema iterSchema;
Iterator<T> iter;

if (hasJoinedPartitionColumns) {
Pair<Schema, Iterator<T>> pair = getJoinedSchemaAndIteratorWithIdentityPartition(file, task, requiredSchema,
idColumns, spec);
iterSchema = pair.first();
iter = pair.second();
} else if (hasExtraFilterColumns) {
// add projection to the final schema
iterSchema = requiredSchema;
iter = open(task, requiredSchema, ImmutableMap.of());
} else {
// return the base iterator
iterSchema = finalSchema;
iter = open(task, finalSchema, ImmutableMap.of());
}

// TODO: remove the projection by reporting the iterator's schema back to Spark
return Iterators.transform(
iter,
APPLY_PROJECTION.bind(projection(finalSchema, iterSchema))::invoke);
}

static UnsafeProjection projection(Schema finalSchema, Schema readSchema) {
StructType struct = SparkSchemaUtil.convert(readSchema);

List<AttributeReference> refs = JavaConverters.seqAsJavaListConverter(struct.toAttributes()).asJava();
List<Attribute> attrs = Lists.newArrayListWithExpectedSize(struct.fields().length);
List<org.apache.spark.sql.catalyst.expressions.Expression> exprs =
Lists.newArrayListWithExpectedSize(struct.fields().length);

for (AttributeReference ref : refs) {
attrs.add(ref.toAttribute());
}

for (Types.NestedField field : finalSchema.columns()) {
int indexInReadSchema = struct.fieldIndex(field.name());
exprs.add(refs.get(indexInReadSchema));
}

return UnsafeProjection.create(
JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(),
JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq());
}

@Override
public void close() throws IOException {
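A minimal, hypothetical usage sketch of projection() and APPLY_PROJECTION (not part of this change): the class name ProjectionSketch and both schemas are invented, and it assumes the file is compiled into the org.apache.iceberg.spark.source package so the package-private helpers are visible. Given an iterator schema that carries an extra residual-filter column, the generated UnsafeProjection drops that column and reorders the remaining fields into the expected schema.

package org.apache.iceberg.spark.source;

import org.apache.iceberg.Schema;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.unsafe.types.UTF8String;

public class ProjectionSketch {
  public static void main(String[] args) {
    // expected (final) schema: only "id"
    Schema finalSchema = new Schema(Types.NestedField.required(1, "id", Types.LongType.get()));

    // schema of the rows the reader actually produced: "id" plus a residual-filter column "category"
    Schema iterSchema = new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),
        Types.NestedField.optional(2, "category", Types.StringType.get()));

    UnsafeProjection toFinal = BaseDataReader.projection(finalSchema, iterSchema);

    // a row in the iterator schema; the projection keeps only "id" in the final column order
    InternalRow wide = new GenericInternalRow(new Object[] {42L, UTF8String.fromString("a")});
    InternalRow narrow = BaseDataReader.APPLY_PROJECTION.bind(toFinal).invoke(wide);

    System.out.println(narrow.getLong(0));  // prints 42
  }
}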
spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java (114 changes: 21 additions & 93 deletions)
@@ -22,12 +22,10 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.avro.generic.GenericData;
@@ -40,109 +38,61 @@
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.avro.Avro;
import org.apache.iceberg.common.DynMethods;
import org.apache.iceberg.encryption.EncryptionManager;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.data.SparkAvroReader;
import org.apache.iceberg.spark.data.SparkOrcReader;
import org.apache.iceberg.spark.data.SparkParquetReaders;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.Pair;
import org.apache.iceberg.util.PartitionUtil;
import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.JoinedRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
import scala.collection.JavaConverters;

class RowDataReader extends BaseDataReader<InternalRow> {
private static final Set<FileFormat> SUPPORTS_CONSTANTS = Sets.newHashSet(FileFormat.AVRO, FileFormat.PARQUET);
// for some reason, the apply method can't be called from Java without reflection
private static final DynMethods.UnboundMethod APPLY_PROJECTION = DynMethods.builder("apply")
.impl(UnsafeProjection.class, InternalRow.class)
.build();

private final Schema tableSchema;
private final Schema expectedSchema;
private final boolean caseSensitive;

RowDataReader(
CombinedScanTask task, Schema tableSchema, Schema expectedSchema, FileIO fileIo,
EncryptionManager encryptionManager, boolean caseSensitive) {
super(task, fileIo, encryptionManager);
this.tableSchema = tableSchema;
this.expectedSchema = expectedSchema;
this.caseSensitive = caseSensitive;
super(task, fileIo, encryptionManager, tableSchema, expectedSchema, caseSensitive);
}

@Override
Iterator<InternalRow> open(FileScanTask task) {
DataFile file = task.file();

// update the current file for Spark's filename() function
InputFileBlockHolder.set(file.path().toString(), task.start(), task.length());

// schema or rows returned by readers
Schema finalSchema = expectedSchema;
PartitionSpec spec = task.spec();
Set<Integer> idColumns = spec.identitySourceIds();

// schema needed for the projection and filtering
StructType sparkType = SparkSchemaUtil.convert(finalSchema);
Schema requiredSchema = SparkSchemaUtil.prune(tableSchema, sparkType, task.residual(), caseSensitive);
boolean hasJoinedPartitionColumns = !idColumns.isEmpty();
boolean hasExtraFilterColumns = requiredSchema.columns().size() != finalSchema.columns().size();

Pair<Schema, Iterator<InternalRow>> getJoinedSchemaAndIteratorWithIdentityPartition(DataFile file, FileScanTask task,
Schema requiredSchema, Set<Integer> idColumns, PartitionSpec spec) {
Schema iterSchema;
Iterator<InternalRow> iter;

if (hasJoinedPartitionColumns) {
if (SUPPORTS_CONSTANTS.contains(file.format())) {
iterSchema = requiredSchema;
iter = open(task, requiredSchema, PartitionUtil.constantsMap(task, RowDataReader::convertConstant));
} else {
// schema used to read data files
Schema readSchema = TypeUtil.selectNot(requiredSchema, idColumns);
Schema partitionSchema = TypeUtil.select(requiredSchema, idColumns);
PartitionRowConverter convertToRow = new PartitionRowConverter(partitionSchema, spec);
JoinedRow joined = new JoinedRow();

InternalRow partition = convertToRow.apply(file.partition());
joined.withRight(partition);

// create joined rows and project from the joined schema to the final schema
iterSchema = TypeUtil.join(readSchema, partitionSchema);
iter = Iterators.transform(open(task, readSchema, ImmutableMap.of()), joined::withLeft);
}
} else if (hasExtraFilterColumns) {
// add projection to the final schema
if (SUPPORTS_CONSTANTS.contains(file.format())) {
iterSchema = requiredSchema;
iter = open(task, requiredSchema, ImmutableMap.of());
iter = open(task, requiredSchema, PartitionUtil.constantsMap(task, RowDataReader::convertConstant));
} else {
// return the base iterator
iterSchema = finalSchema;
iter = open(task, finalSchema, ImmutableMap.of());
// schema used to read data files
Schema readSchema = TypeUtil.selectNot(requiredSchema, idColumns);
Schema partitionSchema = TypeUtil.select(requiredSchema, idColumns);
PartitionRowConverter convertToRow = new PartitionRowConverter(partitionSchema, spec);
JoinedRow joined = new JoinedRow();

InternalRow partition = convertToRow.apply(file.partition());
joined.withRight(partition);

// create joined rows and project from the joined schema to the final schema
iterSchema = TypeUtil.join(readSchema, partitionSchema);
iter = Iterators.transform(open(task, readSchema, ImmutableMap.of()), joined::withLeft);
}

// TODO: remove the projection by reporting the iterator's schema back to Spark
return Iterators.transform(
iter,
APPLY_PROJECTION.bind(projection(finalSchema, iterSchema))::invoke);
return Pair.of(iterSchema, iter);
}

private Iterator<InternalRow> open(FileScanTask task, Schema readSchema, Map<Integer, ?> idToConstant) {
@Override
Iterator<InternalRow> open(FileScanTask task, Schema readSchema, Map<Integer, ?> idToConstant) {
CloseableIterable<InternalRow> iter;
if (task.isDataTask()) {
iter = newDataIterable(task.asDataTask(), readSchema);
@@ -221,28 +171,6 @@ private CloseableIterable<InternalRow> newDataIterable(DataTask task, Schema rea
asSparkRows, APPLY_PROJECTION.bind(projection(readSchema, tableSchema))::invoke);
}

private static UnsafeProjection projection(Schema finalSchema, Schema readSchema) {
StructType struct = SparkSchemaUtil.convert(readSchema);

List<AttributeReference> refs = JavaConverters.seqAsJavaListConverter(struct.toAttributes()).asJava();
List<Attribute> attrs = Lists.newArrayListWithExpectedSize(struct.fields().length);
List<org.apache.spark.sql.catalyst.expressions.Expression> exprs =
Lists.newArrayListWithExpectedSize(struct.fields().length);

for (AttributeReference ref : refs) {
attrs.add(ref.toAttribute());
}

for (Types.NestedField field : finalSchema.columns()) {
int indexInReadSchema = struct.fieldIndex(field.name());
exprs.add(refs.get(indexInReadSchema));
}

return UnsafeProjection.create(
JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(),
JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq());
}

private static Object convertConstant(Type type, Object value) {
if (value == null) {
return null;