diff --git a/.gitignore b/.gitignore index 7b1437228856..af76d7cc0ea1 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ sdist/ coverage.xml .pytest_cache/ spark/tmp/ +spark-warehouse/ spark/spark-warehouse/ spark2/spark-warehouse/ spark3/spark-warehouse/ diff --git a/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java b/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java index e2b2e4f809dd..079190b12c26 100644 --- a/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java +++ b/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java @@ -23,6 +23,8 @@ import java.io.IOException; import java.util.Collections; import java.util.Iterator; +import java.util.function.Function; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; public interface CloseableIterator extends Iterator, Closeable { @@ -54,4 +56,25 @@ public E next() { } }; } + + static CloseableIterator transform(CloseableIterator iterator, Function transform) { + Preconditions.checkNotNull(transform, "Cannot apply a null transform"); + + return new CloseableIterator() { + @Override + public void close() throws IOException { + iterator.close(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public O next() { + return transform.apply(iterator.next()); + } + }; + } } diff --git a/orc/src/main/java/org/apache/iceberg/orc/ORC.java b/orc/src/main/java/org/apache/iceberg/orc/ORC.java index 9e0354967844..38dc522fa5cf 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/ORC.java +++ b/orc/src/main/java/org/apache/iceberg/orc/ORC.java @@ -127,6 +127,8 @@ public static class ReadBuilder { private boolean caseSensitive = true; private Function> readerFunc; + private Function> batchedReaderFunc; + private int recordsPerBatch = VectorizedRowBatch.DEFAULT_SIZE; private ReadBuilder(InputFile file) { Preconditions.checkNotNull(file, "Input file cannot be null"); @@ -168,6 +170,8 @@ public ReadBuilder config(String property, String value) { } public ReadBuilder createReaderFunc(Function> readerFunction) { + Preconditions.checkArgument(this.batchedReaderFunc == null, + "Reader function cannot be set since the batched version is already set"); this.readerFunc = readerFunction; return this; } @@ -177,9 +181,22 @@ public ReadBuilder filter(Expression newFilter) { return this; } + public ReadBuilder createBatchedReaderFunc(Function> batchReaderFunction) { + Preconditions.checkArgument(this.readerFunc == null, + "Batched reader function cannot be set since the non-batched version is already set"); + this.batchedReaderFunc = batchReaderFunction; + return this; + } + + public ReadBuilder recordsPerBatch(int numRecordsPerBatch) { + this.recordsPerBatch = numRecordsPerBatch; + return this; + } + public CloseableIterable build() { Preconditions.checkNotNull(schema, "Schema is required"); - return new OrcIterable<>(file, conf, schema, start, length, readerFunc, caseSensitive, filter); + return new OrcIterable<>(file, conf, schema, start, length, readerFunc, caseSensitive, filter, batchedReaderFunc, + recordsPerBatch); } } diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcBatchReader.java b/orc/src/main/java/org/apache/iceberg/orc/OrcBatchReader.java new file mode 100644 index 000000000000..86dcc656e4f9 --- /dev/null +++ b/orc/src/main/java/org/apache/iceberg/orc/OrcBatchReader.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.orc; + +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; + +/** + * Used for implementing ORC batch readers. + */ +@FunctionalInterface +public interface OrcBatchReader { + + /** + * Reads a row batch. + */ + T read(VectorizedRowBatch batch); + +} diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java b/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java index 88a421a1b496..4dc29b11e59e 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java +++ b/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java @@ -20,7 +20,6 @@ package org.apache.iceberg.orc; import java.io.IOException; -import java.util.Iterator; import java.util.function.Function; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.Schema; @@ -49,10 +48,13 @@ class OrcIterable extends CloseableGroup implements CloseableIterable { private final Function> readerFunction; private final Expression filter; private final boolean caseSensitive; + private final Function> batchReaderFunction; + private final int recordsPerBatch; OrcIterable(InputFile file, Configuration config, Schema schema, Long start, Long length, - Function> readerFunction, boolean caseSensitive, Expression filter) { + Function> readerFunction, boolean caseSensitive, Expression filter, + Function> batchReaderFunction, int recordsPerBatch) { this.schema = schema; this.readerFunction = readerFunction; this.file = file; @@ -61,6 +63,8 @@ class OrcIterable extends CloseableGroup implements CloseableIterable { this.config = config; this.caseSensitive = caseSensitive; this.filter = (filter == Expressions.alwaysTrue()) ? 
null : filter; + this.batchReaderFunction = batchReaderFunction; + this.recordsPerBatch = recordsPerBatch; } @SuppressWarnings("unchecked") @@ -75,16 +79,22 @@ public CloseableIterator iterator() { Expression boundFilter = Binder.bind(schema.asStruct(), filter, caseSensitive); sarg = ExpressionToSearchArgument.convert(boundFilter, readOrcSchema); } - Iterator iterator = new OrcIterator( - newOrcIterator(file, readOrcSchema, start, length, orcFileReader, sarg), - readerFunction.apply(readOrcSchema)); - return CloseableIterator.withClose(iterator); + + VectorizedRowBatchIterator rowBatchIterator = newOrcIterator(file, readOrcSchema, start, length, orcFileReader, + sarg, recordsPerBatch); + if (batchReaderFunction != null) { + OrcBatchReader batchReader = (OrcBatchReader) batchReaderFunction.apply(readOrcSchema); + return CloseableIterator.transform(rowBatchIterator, batchReader::read); + } else { + return new OrcRowIterator<>(rowBatchIterator, (OrcRowReader) readerFunction.apply(readOrcSchema)); + } } private static VectorizedRowBatchIterator newOrcIterator(InputFile file, TypeDescription readerSchema, Long start, Long length, - Reader orcFileReader, SearchArgument sarg) { + Reader orcFileReader, SearchArgument sarg, + int recordsPerBatch) { final Reader.Options options = orcFileReader.options(); if (start != null) { options.range(start, length); @@ -93,13 +103,14 @@ private static VectorizedRowBatchIterator newOrcIterator(InputFile file, options.searchArgument(sarg, new String[]{}); try { - return new VectorizedRowBatchIterator(file.location(), readerSchema, orcFileReader.rows(options)); + return new VectorizedRowBatchIterator(file.location(), readerSchema, orcFileReader.rows(options), + recordsPerBatch); } catch (IOException ioe) { throw new RuntimeIOException(ioe, "Failed to get ORC rows for file: %s", file); } } - private static class OrcIterator implements Iterator { + private static class OrcRowIterator implements CloseableIterator { private int nextRow; private VectorizedRowBatch current; @@ -107,7 +118,7 @@ private static class OrcIterator implements Iterator { private final VectorizedRowBatchIterator batchIter; private final OrcRowReader reader; - OrcIterator(VectorizedRowBatchIterator batchIter, OrcRowReader reader) { + OrcRowIterator(VectorizedRowBatchIterator batchIter, OrcRowReader reader) { this.batchIter = batchIter; this.reader = reader; current = null; @@ -128,6 +139,10 @@ public T next() { return this.reader.read(current, nextRow++); } - } + @Override + public void close() throws IOException { + batchIter.close(); + } + } } diff --git a/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java b/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java index 125a37c39498..7f3abbf0a0a4 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java +++ b/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java @@ -19,10 +19,9 @@ package org.apache.iceberg.orc; -import java.io.Closeable; import java.io.IOException; -import java.util.Iterator; import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.CloseableIterator; import org.apache.orc.RecordReader; import org.apache.orc.TypeDescription; import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; @@ -32,16 +31,16 @@ * Because the same VectorizedRowBatch is reused on each call to next, * it gets changed when hasNext or next is called. 
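 * Callers that need to keep row data across calls should therefore copy it out of the batch before advancing.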
*/ -public class VectorizedRowBatchIterator implements Iterator, Closeable { +public class VectorizedRowBatchIterator implements CloseableIterator { private final String fileLocation; private final RecordReader rows; private final VectorizedRowBatch batch; private boolean advanced = false; - VectorizedRowBatchIterator(String fileLocation, TypeDescription schema, RecordReader rows) { + VectorizedRowBatchIterator(String fileLocation, TypeDescription schema, RecordReader rows, int recordsPerBatch) { this.fileLocation = fileLocation; this.rows = rows; - this.batch = schema.createRowBatch(); + this.batch = schema.createRowBatch(recordsPerBatch); } @Override diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java index da1ee6e301df..6e6eb56e3198 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java @@ -32,7 +32,6 @@ import org.apache.orc.storage.ql.exec.vector.StructColumnVector; import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.Decimal; /** * Converts the OrcIterator, which returns ORC's VectorizedRowBatch to a @@ -103,11 +102,7 @@ public OrcValueReader primitive(Type.PrimitiveType iPrimitive, TypeDescriptio case TIMESTAMP_INSTANT: return SparkOrcValueReaders.timestampTzs(); case DECIMAL: - if (primitive.getPrecision() <= Decimal.MAX_LONG_DIGITS()) { - return new SparkOrcValueReaders.Decimal18Reader(primitive.getPrecision(), primitive.getScale()); - } else { - return new SparkOrcValueReaders.Decimal38Reader(primitive.getPrecision(), primitive.getScale()); - } + return SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); case CHAR: case VARCHAR: case STRING: diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java index 5add4994aab0..ab9ee43fc44c 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java @@ -42,19 +42,26 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.UTF8String; - -class SparkOrcValueReaders { +public class SparkOrcValueReaders { private SparkOrcValueReaders() { } - static OrcValueReader utf8String() { + public static OrcValueReader utf8String() { return StringReader.INSTANCE; } - static OrcValueReader timestampTzs() { + public static OrcValueReader timestampTzs() { return TimestampTzReader.INSTANCE; } + public static OrcValueReader decimals(int precision, int scale) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return new SparkOrcValueReaders.Decimal18Reader(precision, scale); + } else { + return new SparkOrcValueReaders.Decimal38Reader(precision, scale); + } + } + static OrcValueReader struct( List> readers, Types.StructType struct, Map idToConstant) { return new StructReader(readers, struct, idToConstant); @@ -164,7 +171,7 @@ public Long nonNullRead(ColumnVector vector, int row) { } } - static class Decimal18Reader implements OrcValueReader { + private static class Decimal18Reader implements OrcValueReader { //TODO: these are being unused. 
check for bug private final int precision; private final int scale; @@ -181,7 +188,7 @@ public Decimal nonNullRead(ColumnVector vector, int row) { } } - static class Decimal38Reader implements OrcValueReader { + private static class Decimal38Reader implements OrcValueReader { private final int precision; private final int scale; diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java similarity index 65% rename from spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java rename to spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java index 8770d13ab883..c3acbc4f0d00 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java @@ -21,105 +21,104 @@ import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.types.Type; -import org.apache.iceberg.types.Types; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; -public class NullValuesColumnVector extends ColumnVector { +class ConstantColumnVector extends ColumnVector { - private final int numNulls; - private static final Type NULL_TYPE = Types.IntegerType.get(); + private final Object constant; + private final int batchSize; - public NullValuesColumnVector(int nValues) { - super(SparkSchemaUtil.convert(NULL_TYPE)); - this.numNulls = nValues; + ConstantColumnVector(Type type, int batchSize, Object constant) { + super(SparkSchemaUtil.convert(type)); + this.constant = constant; + this.batchSize = batchSize; } @Override public void close() { - } @Override public boolean hasNull() { - return true; + return constant == null; } @Override public int numNulls() { - return numNulls; + return constant == null ? batchSize : 0; } @Override public boolean isNullAt(int rowId) { - return true; + return constant == null; } @Override public boolean getBoolean(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? (boolean) constant : false; } @Override public byte getByte(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? (byte) constant : 0; } @Override public short getShort(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? (short) constant : 0; } @Override public int getInt(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? (int) constant : 0; } @Override public long getLong(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? (long) constant : 0L; } @Override public float getFloat(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? (float) constant : 0.0F; } @Override public double getDouble(int rowId) { - throw new UnsupportedOperationException(); + return constant != null ? 
(double) constant : 0.0; } @Override public ColumnarArray getArray(int rowId) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); } @Override public ColumnarMap getMap(int ordinal) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); } @Override public Decimal getDecimal(int rowId, int precision, int scale) { - throw new UnsupportedOperationException(); + return (Decimal) constant; } @Override public UTF8String getUTF8String(int rowId) { - throw new UnsupportedOperationException(); + return (UTF8String) constant; } @Override public byte[] getBinary(int rowId) { - throw new UnsupportedOperationException(); + return (byte[]) constant; } @Override protected ColumnVector getChild(int ordinal) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); } } diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java index 9d10cd935512..60cd17e06ed3 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java @@ -22,6 +22,7 @@ import org.apache.iceberg.arrow.vectorized.NullabilityHolder; import org.apache.iceberg.arrow.vectorized.VectorHolder; import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.vectorized.ArrowColumnVector; import org.apache.spark.sql.vectorized.ColumnVector; @@ -143,7 +144,7 @@ public ArrowColumnVector getChild(int ordinal) { } static ColumnVector forHolder(VectorHolder holder, int numRows) { - return holder.isDummy() ? new NullValuesColumnVector(numRows) : + return holder.isDummy() ? new ConstantColumnVector(Types.IntegerType.get(), numRows, null) : new IcebergArrowColumnVector(holder); } diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java new file mode 100644 index 000000000000..564fcfa0b3da --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.Schema; +import org.apache.iceberg.orc.OrcBatchReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkOrcValueReaders; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class VectorizedSparkOrcReaders { + + private VectorizedSparkOrcReaders() { + } + + public static OrcBatchReader buildReader(Schema expectedSchema, TypeDescription fileSchema, + Map idToConstant) { + Converter converter = OrcSchemaWithTypeVisitor.visit(expectedSchema, fileSchema, new ReadBuilder(idToConstant)); + + return batch -> { + BaseOrcColumnVector cv = (BaseOrcColumnVector) converter.convert(new StructColumnVector(batch.size, batch.cols), + batch.size); + ColumnarBatch columnarBatch = new ColumnarBatch(IntStream.range(0, expectedSchema.columns().size()) + .mapToObj(cv::getChild) + .toArray(ColumnVector[]::new)); + columnarBatch.setNumRows(batch.size); + return columnarBatch; + }; + } + + private interface Converter { + ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector columnVector, int batchSize); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public Converter record(Types.StructType iStruct, TypeDescription record, List names, + List fields) { + return new StructConverter(iStruct, fields, idToConstant); + } + + @Override + public Converter list(Types.ListType iList, TypeDescription array, Converter element) { + return new ArrayConverter(iList, element); + } + + @Override + public Converter map(Types.MapType iMap, TypeDescription map, Converter key, Converter value) { + return new MapConverter(iMap, key, value); + } + + @Override + public Converter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + final OrcValueReader primitiveValueReader; + switch (primitive.getCategory()) { + case BOOLEAN: + primitiveValueReader = OrcValueReaders.booleans(); + break; + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. 
Use int + case DATE: + case INT: + primitiveValueReader = OrcValueReaders.ints(); + break; + case LONG: + primitiveValueReader = OrcValueReaders.longs(); + break; + case FLOAT: + primitiveValueReader = OrcValueReaders.floats(); + break; + case DOUBLE: + primitiveValueReader = OrcValueReaders.doubles(); + break; + case TIMESTAMP_INSTANT: + primitiveValueReader = SparkOrcValueReaders.timestampTzs(); + break; + case DECIMAL: + primitiveValueReader = SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + break; + case CHAR: + case VARCHAR: + case STRING: + primitiveValueReader = SparkOrcValueReaders.utf8String(); + break; + case BINARY: + primitiveValueReader = OrcValueReaders.bytes(); + break; + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + return (columnVector, batchSize) -> + new PrimitiveOrcColumnVector(iPrimitive, batchSize, columnVector, primitiveValueReader); + } + } + + private abstract static class BaseOrcColumnVector extends ColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final int batchSize; + private Integer numNulls; + + BaseOrcColumnVector(Type type, int batchSize, org.apache.orc.storage.ql.exec.vector.ColumnVector vector) { + super(SparkSchemaUtil.convert(type)); + this.vector = vector; + this.batchSize = batchSize; + } + + @Override + public void close() { + } + + @Override + public boolean hasNull() { + return !vector.noNulls; + } + + @Override + public int numNulls() { + if (numNulls == null) { + numNulls = numNullsHelper(); + } + return numNulls; + } + + private int numNullsHelper() { + if (vector.isRepeating) { + if (vector.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (vector.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (vector.isNull[i]) { + count++; + } + } + return count; + } + } + + protected int getRowIndex(int rowId) { + return vector.isRepeating ? 
0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return vector.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } + } + + private static class PrimitiveOrcColumnVector extends BaseOrcColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final OrcValueReader primitiveValueReader; + + PrimitiveOrcColumnVector(Type type, int batchSize, org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + OrcValueReader primitiveValueReader) { + super(type, batchSize, vector); + this.vector = vector; + this.primitiveValueReader = primitiveValueReader; + } + + @Override + public boolean getBoolean(int rowId) { + Boolean value = (Boolean) primitiveValueReader.read(vector, rowId); + return value != null ? value : false; + } + + @Override + public int getInt(int rowId) { + Integer value = (Integer) primitiveValueReader.read(vector, rowId); + return value != null ? value : 0; + } + + @Override + public long getLong(int rowId) { + Long value = (Long) primitiveValueReader.read(vector, rowId); + return value != null ? value : 0L; + } + + @Override + public float getFloat(int rowId) { + Float value = (Float) primitiveValueReader.read(vector, rowId); + return value != null ? value : 0.0F; + } + + @Override + public double getDouble(int rowId) { + Double value = (Double) primitiveValueReader.read(vector, rowId); + return value != null ? value : 0.0; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + // TODO: Is it okay to assume that (precision,scale) parameters == (precision,scale) of the decimal type + // and return a Decimal with (precision,scale) of the decimal type? 
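+      // Assumption: Spark passes the (precision, scale) of the same decimal type that was used to build
+      // primitiveValueReader, so the Decimal produced by the reader is returned as-is without rescaling.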
+ return (Decimal) primitiveValueReader.read(vector, rowId); + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) primitiveValueReader.read(vector, rowId); + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) primitiveValueReader.read(vector, rowId); + } + } + + private static class ArrayConverter implements Converter { + private final Types.ListType listType; + private final Converter elementConverter; + + private ArrayConverter(Types.ListType listType, Converter elementConverter) { + this.listType = listType; + this.elementConverter = elementConverter; + } + + @Override + public ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int batchSize) { + ListColumnVector listVector = (ListColumnVector) vector; + ColumnVector elementVector = elementConverter.convert(listVector.child, batchSize); + + return new BaseOrcColumnVector(listType, batchSize, vector) { + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } else { + int index = getRowIndex(rowId); + return new ColumnarArray(elementVector, (int) listVector.offsets[index], (int) listVector.lengths[index]); + } + } + }; + } + } + + private static class MapConverter implements Converter { + private final Types.MapType mapType; + private final Converter keyConverter; + private final Converter valueConverter; + + private MapConverter(Types.MapType mapType, Converter keyConverter, Converter valueConverter) { + this.mapType = mapType; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + @Override + public ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int batchSize) { + MapColumnVector mapVector = (MapColumnVector) vector; + ColumnVector keyVector = keyConverter.convert(mapVector.keys, batchSize); + ColumnVector valueVector = valueConverter.convert(mapVector.values, batchSize); + + return new BaseOrcColumnVector(mapType, batchSize, vector) { + @Override + public ColumnarMap getMap(int rowId) { + if (isNullAt(rowId)) { + return null; + } else { + int index = getRowIndex(rowId); + return new ColumnarMap(keyVector, valueVector, (int) mapVector.offsets[index], + (int) mapVector.lengths[index]); + } + } + }; + } + } + + private static class StructConverter implements Converter { + private final Types.StructType structType; + private final List fieldConverters; + private final Map idToConstant; + + private StructConverter(Types.StructType structType, List fieldConverters, + Map idToConstant) { + this.structType = structType; + this.fieldConverters = fieldConverters; + this.idToConstant = idToConstant; + } + + @Override + public ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int batchSize) { + StructColumnVector structVector = (StructColumnVector) vector; + List fields = structType.fields(); + List fieldVectors = Lists.newArrayListWithExpectedSize(fields.size()); + for (int pos = 0, vectorIndex = 0; pos < fields.size(); pos += 1) { + Types.NestedField field = fields.get(pos); + if (idToConstant.containsKey(field.fieldId())) { + fieldVectors.add(new ConstantColumnVector(field.type(), batchSize, idToConstant.get(field.fieldId()))); + } else { + fieldVectors.add(fieldConverters.get(vectorIndex).convert(structVector.fields[vectorIndex], batchSize)); + vectorIndex++; + } + } + + return new BaseOrcColumnVector(structType, batchSize, vector) { + @Override + public ColumnVector getChild(int ordinal) { + return fieldVectors.get(ordinal); 
+ } + }; + } + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java index 93e03aaabee1..fc87f3f5ce6b 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java @@ -21,8 +21,12 @@ import java.io.Closeable; import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; import java.util.Iterator; import java.util.Map; +import org.apache.avro.generic.GenericData; +import org.apache.avro.util.Utf8; import org.apache.iceberg.CombinedScanTask; import org.apache.iceberg.FileScanTask; import org.apache.iceberg.encryption.EncryptedFiles; @@ -33,7 +37,11 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.util.ByteBuffers; import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; /** * Base class of Spark readers. @@ -99,4 +107,32 @@ InputFile getInputFile(FileScanTask task) { Preconditions.checkArgument(!task.isDataTask(), "Invalid task type"); return inputFiles.get(task.file().path().toString()); } + + protected static Object convertConstant(Type type, Object value) { + if (value == null) { + return null; + } + + switch (type.typeId()) { + case DECIMAL: + return Decimal.apply((BigDecimal) value); + case STRING: + if (value instanceof Utf8) { + Utf8 utf8 = (Utf8) value; + return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); + } + return UTF8String.fromString(value.toString()); + case FIXED: + if (value instanceof byte[]) { + return value; + } else if (value instanceof GenericData.Fixed) { + return ((GenericData.Fixed) value).bytes(); + } + return ByteBuffers.toByteArray((ByteBuffer) value); + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) value); + default: + } + return value; + } } diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java index f784b638d376..eff18ca3100d 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java @@ -19,10 +19,14 @@ package org.apache.iceberg.spark.source; +import java.util.Map; +import java.util.Set; import org.apache.arrow.vector.NullCheckingForGet; import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.encryption.EncryptionManager; import org.apache.iceberg.io.CloseableIterable; @@ -30,9 +34,15 @@ import org.apache.iceberg.io.FileIO; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.orc.ORC; import org.apache.iceberg.parquet.Parquet; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; import 
org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.PartitionUtil; +import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.vectorized.ColumnarBatch; class BatchDataReader extends BaseDataReader { @@ -53,6 +63,24 @@ class BatchDataReader extends BaseDataReader { @Override CloseableIterator open(FileScanTask task) { + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); + + // schema or rows returned by readers + PartitionSpec spec = task.spec(); + Set idColumns = spec.identitySourceIds(); + Schema partitionSchema = TypeUtil.select(expectedSchema, idColumns); + boolean projectsIdentityPartitionColumns = !partitionSchema.columns().isEmpty(); + + Map idToConstant; + if (projectsIdentityPartitionColumns) { + idToConstant = PartitionUtil.constantsMap(task, BatchDataReader::convertConstant); + } else { + idToConstant = ImmutableMap.of(); + } + CloseableIterable iter; InputFile location = getInputFile(task); Preconditions.checkNotNull(location, "Could not find InputFile associated with FileScanTask"); @@ -75,6 +103,17 @@ CloseableIterator open(FileScanTask task) { } iter = builder.build(); + } else if (task.file().format() == FileFormat.ORC) { + Schema schemaWithoutConstants = TypeUtil.selectNot(expectedSchema, idToConstant.keySet()); + iter = ORC.read(location) + .project(schemaWithoutConstants) + .split(task.start(), task.length()) + .createBatchedReaderFunc(fileSchema -> VectorizedSparkOrcReaders.buildReader(expectedSchema, fileSchema, + idToConstant)) + .recordsPerBatch(batchSize) + .filter(task.residual()) + .caseSensitive(caseSensitive) + .build(); } else { throw new UnsupportedOperationException( "Format: " + task.file().format() + " not supported for batched reads"); diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java index 16d5cb9cffca..ff133ed34cff 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java @@ -19,13 +19,9 @@ package org.apache.iceberg.spark.source; -import java.math.BigDecimal; -import java.nio.ByteBuffer; import java.util.List; import java.util.Map; import java.util.Set; -import org.apache.avro.generic.GenericData; -import org.apache.avro.util.Utf8; import org.apache.iceberg.CombinedScanTask; import org.apache.iceberg.DataFile; import org.apache.iceberg.DataTask; @@ -49,19 +45,15 @@ import org.apache.iceberg.spark.data.SparkAvroReader; import org.apache.iceberg.spark.data.SparkOrcReader; import org.apache.iceberg.spark.data.SparkParquetReaders; -import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; -import org.apache.iceberg.util.ByteBuffers; import org.apache.iceberg.util.PartitionUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; -import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.types.UTF8String; import scala.collection.JavaConverters; class 
RowDataReader extends BaseDataReader { @@ -217,32 +209,4 @@ private static UnsafeProjection projection(Schema finalSchema, Schema readSchema JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(), JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq()); } - - private static Object convertConstant(Type type, Object value) { - if (value == null) { - return null; - } - - switch (type.typeId()) { - case DECIMAL: - return Decimal.apply((BigDecimal) value); - case STRING: - if (value instanceof Utf8) { - Utf8 utf8 = (Utf8) value; - return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); - } - return UTF8String.fromString(value.toString()); - case FIXED: - if (value instanceof byte[]) { - return value; - } else if (value instanceof GenericData.Fixed) { - return ((GenericData.Fixed) value).bytes(); - } - return ByteBuffers.toByteArray((ByteBuffer) value); - case BINARY: - return ByteBuffers.toByteArray((ByteBuffer) value); - default: - } - return value; - } } diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java index f603757c2c44..aa0b24785cf0 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -638,7 +638,9 @@ private static void assertEquals(String context, StructType struct, for (int i = 0; i < actual.numFields(); i += 1) { StructField field = struct.fields()[i]; DataType type = field.dataType(); - assertEquals(context + "." + field.name(), type, expected.get(i, type), actual.get(i, type)); + assertEquals(context + "." + field.name(), type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); } } diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java index 6a588504065a..5042d1cc1338 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java @@ -29,8 +29,12 @@ import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; import org.apache.iceberg.types.Types; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.junit.Assert; import org.junit.Test; @@ -81,5 +85,23 @@ private void writeAndValidateRecords(Schema schema, Iterable expect } Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); } + + try (CloseableIterable reader = ORC.read(Files.localInput(testFile)) + .project(schema) + .createBatchedReaderFunc(readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(schema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRows = batchesToRows(reader.iterator()); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + } + + private Iterator batchesToRows(Iterator 
batches) { + return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); } } diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java index 759fcf35f278..15ae2f5f3176 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java @@ -53,16 +53,20 @@ public abstract class TestIdentityPartitionData { @Parameterized.Parameters public static Object[][] parameters() { return new Object[][] { - new Object[] { "parquet" }, - new Object[] { "avro" }, - new Object[] { "orc" } + new Object[] { "parquet", false }, + new Object[] { "parquet", true }, + new Object[] { "avro", false }, + new Object[] { "orc", false }, + new Object[] { "orc", true }, }; } private final String format; + private final boolean vectorized; - public TestIdentityPartitionData(String format) { + public TestIdentityPartitionData(String format, boolean vectorized) { this.format = format; + this.vectorized = vectorized; } private static SparkSession spark = null; @@ -121,7 +125,9 @@ public void setupTable() throws Exception { @Test public void testFullProjection() { List expected = logs.orderBy("id").collectAsList(); - List actual = spark.read().format("iceberg").load(table.location()).orderBy("id").collectAsList(); + List actual = spark.read().format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) + .load(table.location()).orderBy("id").collectAsList(); Assert.assertEquals("Rows should match", expected, actual); } @@ -152,7 +158,9 @@ public void testProjections() { for (String[] ordering : cases) { List expected = logs.select("id", ordering).orderBy("id").collectAsList(); List actual = spark.read() - .format("iceberg").load(table.location()) + .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) + .load(table.location()) .select("id", ordering).orderBy("id") .collectAsList(); Assert.assertEquals("Rows should match for ordering: " + Arrays.toString(ordering), expected, actual); diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java index 07b21748d51d..c46b19166ae5 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java @@ -62,9 +62,11 @@ public abstract class TestPartitionValues { @Parameterized.Parameters public static Object[][] parameters() { return new Object[][] { - new Object[] { "parquet" }, - new Object[] { "avro" }, - new Object[] { "orc" } + new Object[] { "parquet", false }, + new Object[] { "parquet", true }, + new Object[] { "avro", false }, + new Object[] { "orc", false }, + new Object[] { "orc", true } }; } @@ -111,9 +113,11 @@ public static void stopSpark() { public TemporaryFolder temp = new TemporaryFolder(); private final String format; + private final boolean vectorized; - public TestPartitionValues(String format) { + public TestPartitionValues(String format, boolean vectorized) { this.format = format; + this.vectorized = vectorized; } @Test @@ -144,6 +148,7 @@ public void testNullPartitionValue() throws Exception { Dataset result = spark.read() .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) .load(location.toString()); List actual = 
result @@ -183,6 +188,7 @@ public void testReorderedColumns() throws Exception { Dataset result = spark.read() .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) .load(location.toString()); List actual = result @@ -223,6 +229,7 @@ public void testReorderedColumnsNoNullability() throws Exception { Dataset result = spark.read() .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) .load(location.toString()); List actual = result @@ -261,7 +268,9 @@ public void testPartitionValueTypes() throws Exception { .appendFile(DataFiles.fromInputFile(Files.localInput(avroData), 10)) .commit(); - Dataset sourceDF = spark.read().format("iceberg").load(sourceLocation); + Dataset sourceDF = spark.read().format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) + .load(sourceLocation); for (String column : columnNames) { String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); @@ -283,6 +292,7 @@ public void testPartitionValueTypes() throws Exception { List actual = spark.read() .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) .load(location.toString()) .collectAsList(); @@ -323,7 +333,9 @@ public void testNestedPartitionValues() throws Exception { .appendFile(DataFiles.fromInputFile(Files.localInput(avroData), 10)) .commit(); - Dataset sourceDF = spark.read().format("iceberg").load(sourceLocation); + Dataset sourceDF = spark.read().format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) + .load(sourceLocation); for (String column : columnNames) { String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); @@ -345,6 +357,7 @@ public void testNestedPartitionValues() throws Exception { List actual = spark.read() .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) .load(location.toString()) .collectAsList(); @@ -403,6 +416,7 @@ public void testPartitionedByNestedString() throws Exception { // verify List actual = spark.read() .format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) .load(baseLocation) .collectAsList(); diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java index 6c9b32bcdc5e..ac64fa952c71 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java @@ -31,7 +31,6 @@ import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.Table; -import org.apache.iceberg.TableProperties; import org.apache.iceberg.avro.Avro; import org.apache.iceberg.data.Record; import org.apache.iceberg.data.avro.DataWriter; @@ -70,7 +69,8 @@ public static Object[][] parameters() { new Object[] { "parquet", false }, new Object[] { "parquet", true }, new Object[] { "avro", false }, - new Object[] { "orc", false } + new Object[] { "orc", false }, + new Object[] { "orc", true } }; } @@ -148,8 +148,6 @@ protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema table.newAppend().appendFile(file).commit(); - table.updateProperties().set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)).commit(); - // rewrite the read schema for the table's reassigned ids Map idMapping = Maps.newHashMap(); for (int id : allIds(writeSchema)) { @@ -166,6 +164,7 @@ protected Record 
writeAndRead(String desc, Schema writeSchema, Schema readSchema Dataset df = spark.read() .format("org.apache.iceberg.spark.source.TestIcebergSource") .option("iceberg.table.name", desc) + .option("vectorization-enabled", String.valueOf(vectorized)) .load(); return SparkValueConverter.convert(readSchema, df.collectAsList().get(0)); diff --git a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java index 5463c7f9b86e..811d15c242cb 100644 --- a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java +++ b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java @@ -66,7 +66,7 @@ public void tearDownBenchmark() throws IOException { @Benchmark @Threads(1) - public void readIceberg() { + public void readIcebergNonVectorized() { Map tableProperties = Maps.newHashMap(); tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); withTableProperties(tableProperties, () -> { @@ -76,6 +76,19 @@ public void readIceberg() { }); } + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option("vectorization-enabled", "true") + .format("iceberg").load(tableLocation); + materialize(df); + }); + } + @Benchmark @Threads(1) public void readFileSourceVectorized() { @@ -102,7 +115,7 @@ public void readFileSourceNonVectorized() { @Benchmark @Threads(1) - public void readWithProjectionIceberg() { + public void readWithProjectionIcebergNonVectorized() { Map tableProperties = Maps.newHashMap(); tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); withTableProperties(tableProperties, () -> { @@ -112,6 +125,20 @@ public void readWithProjectionIceberg() { }); } + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option("vectorization-enabled", "true") + .format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark @Threads(1) public void readWithProjectionFileSourceVectorized() { diff --git a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java index a4147e65ce1b..a63d4f9a1083 100644 --- a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java +++ b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java @@ -68,7 +68,7 @@ public void tearDownBenchmark() throws IOException { @Benchmark @Threads(1) - public void readIceberg() { + public void readIcebergNonVectorized() { Map tableProperties = Maps.newHashMap(); tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); withTableProperties(tableProperties, () -> { @@ -80,12 +80,13 @@ public void readIceberg() { @Benchmark @Threads(1) - public void readFileSourceVectorized() { 
- Map conf = Maps.newHashMap(); - conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); - conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); - withSQLConf(conf, () -> { - Dataset df = spark().read().orc(dataLocation()); + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option("vectorization-enabled", "true") + .format("iceberg").load(tableLocation); materialize(df); }); } @@ -104,7 +105,7 @@ public void readFileSourceNonVectorized() { @Benchmark @Threads(1) - public void readWithProjectionIceberg() { + public void readWithProjectionIcebergNonVectorized() { Map tableProperties = Maps.newHashMap(); tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); withTableProperties(tableProperties, () -> { @@ -116,12 +117,13 @@ public void readWithProjectionIceberg() { @Benchmark @Threads(1) - public void readWithProjectionFileSourceVectorized() { - Map conf = Maps.newHashMap(); - conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); - conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); - withSQLConf(conf, () -> { - Dataset df = spark().read().orc(dataLocation()).selectExpr("nested.col3"); + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option("vectorization-enabled", "true") + .format("iceberg").load(tableLocation).selectExpr("nested.col3"); materialize(df); }); } diff --git a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java index 51859de75ac2..9fb475bcd551 100644 --- a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java +++ b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java @@ -299,6 +299,13 @@ public boolean enableBatchRead() { .allMatch(fileScanTask -> fileScanTask.file().format().equals( FileFormat.PARQUET))); + boolean allOrcFileScanTasks = + tasks().stream() + .allMatch(combinedScanTask -> !combinedScanTask.isDataTask() && combinedScanTask.files() + .stream() + .allMatch(fileScanTask -> fileScanTask.file().format().equals( + FileFormat.ORC))); + boolean atLeastOneColumn = lazySchema().columns().size() > 0; boolean hasNoIdentityProjections = tasks().stream() @@ -308,8 +315,8 @@ public boolean enableBatchRead() { boolean onlyPrimitives = lazySchema().columns().stream().allMatch(c -> c.type().isPrimitiveType()); - this.readUsingBatch = batchReadsEnabled && allParquetFileScanTasks && atLeastOneColumn && - hasNoIdentityProjections && onlyPrimitives; + this.readUsingBatch = batchReadsEnabled && (allOrcFileScanTasks || + (allParquetFileScanTasks && atLeastOneColumn && hasNoIdentityProjections && onlyPrimitives)); } return readUsingBatch; } diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java index c0d676ec8aa9..0d45179b315c 100644 --- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java +++ 
b/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java @@ -152,18 +152,22 @@ public static void stopSpark() { public TemporaryFolder temp = new TemporaryFolder(); private final String format; + private final boolean vectorized; @Parameterized.Parameters public static Object[][] parameters() { return new Object[][] { - new Object[] { "parquet" }, - new Object[] { "avro" }, - new Object[] { "orc" } + new Object[] { "parquet", false }, + new Object[] { "parquet", true }, + new Object[] { "avro", false }, + new Object[] { "orc", false }, + new Object[] { "orc", true } }; } - public TestFilteredScan(String format) { + public TestFilteredScan(String format, boolean vectorized) { this.format = format; + this.vectorized = vectorized; } private File parent = null; @@ -243,7 +247,7 @@ public void testUnpartitionedIDFilters() { // validate row filtering assertEqualsSafe(SCHEMA.asStruct(), expected(i), - read(unpartitioned.toString(), "id = " + i)); + read(unpartitioned.toString(), vectorized, "id = " + i)); } } @@ -270,7 +274,7 @@ public void testUnpartitionedCaseInsensitiveIDFilters() { // validate row filtering assertEqualsSafe(SCHEMA.asStruct(), expected(i), - read(unpartitioned.toString(), "id = " + i)); + read(unpartitioned.toString(), vectorized, "id = " + i)); } } finally { // return global conf to previous state @@ -294,7 +298,7 @@ public void testUnpartitionedTimestampFilter() { Assert.assertEquals("Should only create one task for a small file", 1, tasks.size()); assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9), - read(unpartitioned.toString(), "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + read(unpartitioned.toString(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); } @Test @@ -321,7 +325,7 @@ public void testBucketPartitionedIDFilters() { Assert.assertEquals("Should create one task for a single bucket", 1, tasks.size()); // validate row filtering - assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(location.toString(), "id = " + i)); + assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(location.toString(), vectorized, "id = " + i)); } } @@ -348,7 +352,7 @@ public void testDayPartitionedTimestampFilters() { Assert.assertEquals("Should create one task for 2017-12-21", 1, tasks.size()); assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9), - read(location.toString(), "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + read(location.toString(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); } { @@ -361,7 +365,7 @@ public void testDayPartitionedTimestampFilters() { List> tasks = reader.planInputPartitions(); Assert.assertEquals("Should create one task for 2017-12-22", 1, tasks.size()); - assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2), read(location.toString(), + assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2), read(location.toString(), vectorized, "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); } @@ -390,7 +394,7 @@ public void testHourPartitionedTimestampFilters() { Assert.assertEquals("Should create 4 tasks for 2017-12-21: 15, 17, 21, 22", 4, tasks.size()); assertEqualsSafe(SCHEMA.asStruct(), expected(8, 9, 7, 6, 5), - read(location.toString(), "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + read(location.toString(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); } { @@ -403,7 +407,7 @@ public void testHourPartitionedTimestampFilters() { List> tasks = 
@@ -403,7 +407,7 @@ public void testHourPartitionedTimestampFilters() {
       List<InputPartition<InternalRow>> tasks = reader.planInputPartitions();
       Assert.assertEquals("Should create 2 tasks for 2017-12-22: 6, 7", 2, tasks.size());
 
-      assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1), read(location.toString(),
+      assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1), read(location.toString(), vectorized,
           "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
               "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
     }
@@ -420,7 +424,7 @@ public void testFilterByNonProjectedColumn() {
     }
 
     assertEqualsSafe(actualProjection.asStruct(), expected, read(
-        unpartitioned.toString(),
+        unpartitioned.toString(), vectorized,
         "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)", "id", "data"));
   }
@@ -435,7 +439,7 @@ public void testFilterByNonProjectedColumn() {
     }
 
     assertEqualsSafe(actualProjection.asStruct(), expected, read(
-        unpartitioned.toString(),
+        unpartitioned.toString(), vectorized,
         "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
             "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)",
         "id"));
@@ -512,6 +516,7 @@ public void testPartitionedByIdStartsWith() {
   public void testUnpartitionedStartsWith() {
     Dataset<Row> df = spark.read()
         .format("iceberg")
+        .option("vectorization-enabled", String.valueOf(vectorized))
         .load(unpartitioned.toString());
 
     List<String> matchedData = df.select("data")
@@ -578,6 +583,7 @@ private File buildPartitionedTable(String desc, PartitionSpec spec, String udf,
     // copy the unpartitioned table into the partitioned table to produce the partitioned data
     Dataset<Row> allRows = spark.read()
         .format("iceberg")
+        .option("vectorization-enabled", String.valueOf(vectorized))
        .load(unpartitioned.toString());
 
     allRows
@@ -608,12 +614,14 @@ private List<Record> testRecords(Schema schema) {
     );
   }
 
-  private static List<Row> read(String table, String expr) {
-    return read(table, expr, "*");
+  private static List<Row> read(String table, boolean vectorized, String expr) {
+    return read(table, vectorized, expr, "*");
   }
 
-  private static List<Row> read(String table, String expr, String select0, String... selectN) {
-    Dataset<Row> dataset = spark.read().format("iceberg").load(table).filter(expr)
+  private static List<Row> read(String table, boolean vectorized, String expr, String select0, String... selectN) {
+    Dataset<Row> dataset = spark.read().format("iceberg")
+        .option("vectorization-enabled", String.valueOf(vectorized))
+        .load(table).filter(expr)
         .select(select0, selectN);
     return dataset.collectAsList();
   }
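The test changes above all follow one pattern: each (format, vectorized) pair becomes a separate run of the same test body. A stripped-down sketch of that JUnit 4 parameterization, with illustrative class and test names that are not part of the patch:

import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
public class FormatAndVectorizationExample {
  private final String format;
  private final boolean vectorized;

  @Parameterized.Parameters(name = "format = {0}, vectorized = {1}")
  public static List<Object[]> parameters() {
    return Arrays.asList(
        new Object[] { "parquet", false },
        new Object[] { "parquet", true },
        new Object[] { "avro", false },
        new Object[] { "orc", false },
        new Object[] { "orc", true });
  }

  public FormatAndVectorizationExample(String format, boolean vectorized) {
    this.format = format;
    this.vectorized = vectorized;
  }

  @Test
  public void runsOncePerCombination() {
    // Each combination above runs this body once; the real suites feed format and vectorized
    // into table properties and read options instead of this trivial assertion.
    Assert.assertTrue(format.length() > 0);
  }
}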
diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
index fd7db75ed41f..9e382bfcf2cf 100644
--- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
+++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
@@ -20,7 +20,7 @@
 package org.apache.iceberg.spark.source;
 
 public class TestIdentityPartitionData24 extends TestIdentityPartitionData {
-  public TestIdentityPartitionData24(String format) {
-    super(format);
+  public TestIdentityPartitionData24(String format, boolean vectorized) {
+    super(format, vectorized);
   }
 }
diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
index f9da71eb2178..d5d891f8ced0 100644
--- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
+++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
@@ -20,7 +20,7 @@
 package org.apache.iceberg.spark.source;
 
 public class TestPartitionValues24 extends TestPartitionValues {
-  public TestPartitionValues24(String format) {
-    super(format);
+  public TestPartitionValues24(String format, boolean vectorized) {
+    super(format, vectorized);
   }
 }
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
index 026f6baccce4..912d90cafb53 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
@@ -155,6 +155,13 @@ public PartitionReaderFactory createReaderFactory() {
             .allMatch(fileScanTask -> fileScanTask.file().format().equals(
                 FileFormat.PARQUET)));
 
+    boolean allOrcFileScanTasks =
+        tasks().stream()
+            .allMatch(combinedScanTask -> !combinedScanTask.isDataTask() && combinedScanTask.files()
+                .stream()
+                .allMatch(fileScanTask -> fileScanTask.file().format().equals(
+                    FileFormat.ORC)));
+
     boolean atLeastOneColumn = expectedSchema.columns().size() > 0;
 
     boolean hasNoIdentityProjections = tasks().stream()
@@ -164,8 +171,8 @@ public PartitionReaderFactory createReaderFactory() {
 
     boolean onlyPrimitives = expectedSchema.columns().stream().allMatch(c -> c.type().isPrimitiveType());
 
-    boolean readUsingBatch = batchReadsEnabled && allParquetFileScanTasks && atLeastOneColumn &&
-        hasNoIdentityProjections && onlyPrimitives;
+    boolean readUsingBatch = batchReadsEnabled && (allOrcFileScanTasks ||
+        (allParquetFileScanTasks && atLeastOneColumn && hasNoIdentityProjections && onlyPrimitives));
 
     return new ReaderFactory(readUsingBatch ? batchSize : 0);
   }
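The Spark 2 and Spark 3 hunks above gate batch reads with the same expression. Restated as a standalone helper (hypothetical name, a sketch rather than project code): ORC scans may take the batch path unconditionally, while Parquet still requires a non-empty, primitive-only projection with no identity-partition columns.

final class BatchReadEligibility {
  private BatchReadEligibility() {
  }

  // Mirrors the readUsingBatch expression in Reader (Spark 2) and SparkBatchScan (Spark 3):
  // ORC scans may batch regardless of projection shape; Parquet keeps the stricter checks.
  static boolean shouldReadUsingBatch(boolean batchReadsEnabled,
                                      boolean allOrcFileScanTasks,
                                      boolean allParquetFileScanTasks,
                                      boolean atLeastOneColumn,
                                      boolean hasNoIdentityProjections,
                                      boolean onlyPrimitives) {
    return batchReadsEnabled && (allOrcFileScanTasks ||
        (allParquetFileScanTasks && atLeastOneColumn && hasNoIdentityProjections && onlyPrimitives));
  }
}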
diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java b/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
index 7dd308dcbed1..9be99383873f 100644
--- a/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
+++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
@@ -22,10 +22,10 @@
 import java.io.File;
 import java.io.IOException;
 import java.sql.Timestamp;
+import java.time.OffsetDateTime;
 import java.util.List;
 import java.util.Locale;
 import java.util.UUID;
-import org.apache.avro.generic.GenericData.Record;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.DataFiles;
@@ -34,14 +34,18 @@
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.avro.Avro;
-import org.apache.iceberg.avro.AvroSchemaUtil;
-import org.apache.iceberg.expressions.Literal;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.data.avro.DataWriter;
+import org.apache.iceberg.data.orc.GenericOrcWriter;
+import org.apache.iceberg.data.parquet.GenericParquetWriter;
 import org.apache.iceberg.hadoop.HadoopTables;
 import org.apache.iceberg.io.FileAppender;
+import org.apache.iceberg.orc.ORC;
 import org.apache.iceberg.parquet.Parquet;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.spark.data.TestHelpers;
+import org.apache.iceberg.spark.data.GenericsHelpers;
 import org.apache.iceberg.transforms.Transform;
 import org.apache.iceberg.transforms.Transforms;
 import org.apache.iceberg.types.Types;
@@ -146,17 +150,22 @@ public static void stopSpark() {
   public TemporaryFolder temp = new TemporaryFolder();
 
   private final String format;
+  private final boolean vectorized;
 
   @Parameterized.Parameters
   public static Object[][] parameters() {
     return new Object[][] {
-        new Object[] { "parquet" },
-        new Object[] { "avro" }
+        new Object[] { "parquet", false },
+        new Object[] { "parquet", true },
+        new Object[] { "avro", false },
+        new Object[] { "orc", false },
+        new Object[] { "orc", true }
     };
   }
 
-  public TestFilteredScan(String format) {
+  public TestFilteredScan(String format, boolean vectorized) {
     this.format = format;
+    this.vectorized = vectorized;
   }
 
   private File parent = null;
@@ -177,13 +186,12 @@ public void writeUnpartitionedTable() throws IOException {
 
     File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString()));
 
-    // create records using the table's schema
-    org.apache.avro.Schema avroSchema = AvroSchemaUtil.convert(tableSchema, "test");
-    this.records = testRecords(avroSchema);
+    this.records = testRecords(tableSchema);
 
     switch (fileFormat) {
       case AVRO:
        try (FileAppender<Record> writer = Avro.write(localOutput(testFile))
+            .createWriterFunc(DataWriter::create)
             .schema(tableSchema)
             .build()) {
           writer.addAll(records);
@@ -192,6 +200,16 @@
 
       case PARQUET:
         try (FileAppender<Record> writer = Parquet.write(localOutput(testFile))
+            .createWriterFunc(GenericParquetWriter::buildWriter)
+            .schema(tableSchema)
+            .build()) {
+          writer.addAll(records);
+        }
+        break;
+
+      case ORC:
+        try (FileAppender<Record> writer = ORC.write(localOutput(testFile))
+            .createWriterFunc(GenericOrcWriter::buildWriter)
             .schema(tableSchema)
             .build()) {
           writer.addAll(records);
@@ -224,7 +242,7 @@ public void testUnpartitionedIDFilters() {
 
       // validate row filtering
       assertEqualsSafe(SCHEMA.asStruct(), expected(i),
-          read(unpartitioned.toString(), "id = " + i));
+          read(unpartitioned.toString(), vectorized, "id = " + i));
     }
   }
@@ -252,7 +270,7 @@ public void testUnpartitionedCaseInsensitiveIDFilters() {
 
         // validate row filtering
         assertEqualsSafe(SCHEMA.asStruct(), expected(i),
-            read(unpartitioned.toString(), "id = " + i));
+            read(unpartitioned.toString(), vectorized, "id = " + i));
       }
     } finally {
       // return global conf to previous state
@@ -275,7 +293,7 @@ public void testUnpartitionedTimestampFilter() {
     Assert.assertEquals("Should only create one task for a small file", 1, tasks.length);
 
     assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9),
-        read(unpartitioned.toString(), "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
+        read(unpartitioned.toString(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
   }
 
   @Test
@@ -299,7 +317,7 @@ public void testBucketPartitionedIDFilters() {
       Assert.assertEquals("Should create one task for a single bucket", 1, tasks.length);
 
       // validate row filtering
-      assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(table.location(), "id = " + i));
+      assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(table.location(), vectorized, "id = " + i));
     }
   }
@@ -323,7 +341,7 @@ public void testDayPartitionedTimestampFilters() {
       Assert.assertEquals("Should create one task for 2017-12-21", 1, tasks.length);
 
       assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9),
-          read(table.location(), "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
+          read(table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
     }
 
     {
@@ -337,7 +355,7 @@ public void testDayPartitionedTimestampFilters() {
       InputPartition[] tasks = scan.planInputPartitions();
       Assert.assertEquals("Should create one task for 2017-12-22", 1, tasks.length);
 
-      assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2), read(table.location(),
+      assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2), read(table.location(), vectorized,
           "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
               "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
     }
@@ -364,7 +382,7 @@ public void testHourPartitionedTimestampFilters() {
       Assert.assertEquals("Should create 4 tasks for 2017-12-21: 15, 17, 21, 22", 4, tasks.length);
 
       assertEqualsSafe(SCHEMA.asStruct(), expected(8, 9, 7, 6, 5),
-          read(table.location(), "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
+          read(table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)"));
     }
 
     {
@@ -378,7 +396,7 @@ public void testHourPartitionedTimestampFilters() {
       InputPartition[] tasks = scan.planInputPartitions();
       Assert.assertEquals("Should create 2 tasks for 2017-12-22: 6, 7", 2, tasks.length);
 
-      assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1), read(table.location(),
+      assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1), read(table.location(), vectorized,
           "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
               "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
     }
@@ -395,7 +413,7 @@ public void testFilterByNonProjectedColumn() {
     }
 
     assertEqualsSafe(actualProjection.asStruct(), expected, read(
-        unpartitioned.toString(),
+        unpartitioned.toString(), vectorized,
         "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)", "id", "data"));
   }
@@ -410,7 +428,7 @@ public void testFilterByNonProjectedColumn() {
     }
 
     assertEqualsSafe(actualProjection.asStruct(), expected, read(
-        unpartitioned.toString(),
+        unpartitioned.toString(), vectorized,
         "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
             "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)",
         "id"));
@@ -450,6 +468,7 @@ public void testPartitionedByIdStartsWith() {
   public void testUnpartitionedStartsWith() {
     Dataset<Row> df = spark.read()
         .format("iceberg")
+        .option("vectorization-enabled", String.valueOf(vectorized))
         .load(unpartitioned.toString());
 
     List<String> matchedData = df.select("data")
@@ -462,12 +481,11 @@ public void testUnpartitionedStartsWith() {
   }
 
   private static Record projectFlat(Schema projection, Record record) {
-    org.apache.avro.Schema avroSchema = AvroSchemaUtil.convert(projection, "test");
-    Record result = new Record(avroSchema);
+    Record result = GenericRecord.create(projection);
     List<Types.NestedField> fields = projection.asStruct().fields();
     for (int i = 0; i < fields.size(); i += 1) {
       Types.NestedField field = fields.get(i);
-      result.put(i, record.get(field.name()));
+      result.set(i, record.getField(field.name()));
     }
     return result;
   }
@@ -477,7 +495,7 @@ public static void assertEqualsUnsafe(Types.StructType struct,
     // TODO: match records by ID
     int numRecords = Math.min(expected.size(), actual.size());
     for (int i = 0; i < numRecords; i += 1) {
-      TestHelpers.assertEqualsUnsafe(struct, expected.get(i), actual.get(i));
+      GenericsHelpers.assertEqualsUnsafe(struct, expected.get(i), actual.get(i));
     }
     Assert.assertEquals("Number of results should match expected", expected.size(), actual.size());
   }
@@ -487,7 +505,7 @@ public static void assertEqualsSafe(Types.StructType struct,
     // TODO: match records by ID
     int numRecords = Math.min(expected.size(), actual.size());
     for (int i = 0; i < numRecords; i += 1) {
-      TestHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i));
+      GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i));
     }
     Assert.assertEquals("Number of results should match expected", expected.size(), actual.size());
   }
@@ -517,6 +535,7 @@ private Table buildPartitionedTable(String desc, PartitionSpec spec, String udf,
     // copy the unpartitioned table into the partitioned table to produce the partitioned data
     Dataset<Row> allRows = spark.read()
         .format("iceberg")
+        .option("vectorization-enabled", String.valueOf(vectorized))
         .load(unpartitioned.toString());
 
     allRows
@@ -534,39 +553,41 @@ private Table buildPartitionedTable(String desc, PartitionSpec spec, String udf,
     return table;
   }
 
-  private List<Record> testRecords(org.apache.avro.Schema avroSchema) {
+  private List<Record> testRecords(Schema schema) {
     return Lists.newArrayList(
-        record(avroSchema, 0L, timestamp("2017-12-22T09:20:44.294658+00:00"), "junction"),
-        record(avroSchema, 1L, timestamp("2017-12-22T07:15:34.582910+00:00"), "alligator"),
-        record(avroSchema, 2L, timestamp("2017-12-22T06:02:09.243857+00:00"), "forrest"),
-        record(avroSchema, 3L, timestamp("2017-12-22T03:10:11.134509+00:00"), "clapping"),
-        record(avroSchema, 4L, timestamp("2017-12-22T00:34:00.184671+00:00"), "brush"),
-        record(avroSchema, 5L, timestamp("2017-12-21T22:20:08.935889+00:00"), "trap"),
-        record(avroSchema, 6L, timestamp("2017-12-21T21:55:30.589712+00:00"), "element"),
-        record(avroSchema, 7L, timestamp("2017-12-21T17:31:14.532797+00:00"), "limited"),
-        record(avroSchema, 8L, timestamp("2017-12-21T15:21:51.237521+00:00"), "global"),
-        record(avroSchema, 9L, timestamp("2017-12-21T15:02:15.230570+00:00"), "goldfish")
+        record(schema, 0L, parse("2017-12-22T09:20:44.294658+00:00"), "junction"),
+        record(schema, 1L, parse("2017-12-22T07:15:34.582910+00:00"), "alligator"),
+        record(schema, 2L, parse("2017-12-22T06:02:09.243857+00:00"), "forrest"),
+        record(schema, 3L, parse("2017-12-22T03:10:11.134509+00:00"), "clapping"),
+        record(schema, 4L, parse("2017-12-22T00:34:00.184671+00:00"), "brush"),
+        record(schema, 5L, parse("2017-12-21T22:20:08.935889+00:00"), "trap"),
+        record(schema, 6L, parse("2017-12-21T21:55:30.589712+00:00"), "element"),
+        record(schema, 7L, parse("2017-12-21T17:31:14.532797+00:00"), "limited"),
+        record(schema, 8L, parse("2017-12-21T15:21:51.237521+00:00"), "global"),
+        record(schema, 9L, parse("2017-12-21T15:02:15.230570+00:00"), "goldfish")
     );
   }
 
-  private static List<Row> read(String table, String expr) {
-    return read(table, expr, "*");
+  private static List<Row> read(String table, boolean vectorized, String expr) {
+    return read(table, vectorized, expr, "*");
   }
 
-  private static List<Row> read(String table, String expr, String select0, String... selectN) {
-    Dataset<Row> dataset = spark.read().format("iceberg").load(table).filter(expr)
+  private static List<Row> read(String table, boolean vectorized, String expr, String select0, String... selectN) {
+    Dataset<Row> dataset = spark.read().format("iceberg")
+        .option("vectorization-enabled", String.valueOf(vectorized))
+        .load(table).filter(expr)
         .select(select0, selectN);
     return dataset.collectAsList();
   }
 
-  private static long timestamp(String timestamp) {
-    return Literal.of(timestamp).to(Types.TimestampType.withZone()).value();
+  private static OffsetDateTime parse(String timestamp) {
+    return OffsetDateTime.parse(timestamp);
   }
 
-  private static Record record(org.apache.avro.Schema schema, Object... values) {
-    Record rec = new Record(schema);
+  private static Record record(Schema schema, Object... values) {
+    Record rec = GenericRecord.create(schema);
     for (int i = 0; i < values.length; i += 1) {
-      rec.put(i, values[i]);
+      rec.set(i, values[i]);
     }
     return rec;
   }
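The spark3 test now builds its data with Iceberg generics instead of Avro records. A small self-contained sketch of that API, using an illustrative schema rather than the suite's SCHEMA constant:

import java.time.OffsetDateTime;
import org.apache.iceberg.Schema;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.types.Types;

public class GenericRecordExample {
  public static void main(String[] args) {
    // Illustrative three-column schema; the real tests reuse the suite's schema.
    Schema schema = new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),
        Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()),
        Types.NestedField.optional(3, "data", Types.StringType.get()));

    // Positional set(...) matches the record(...) helper above; setField(...) sets by name instead.
    Record rec = GenericRecord.create(schema);
    rec.set(0, 0L);
    rec.set(1, OffsetDateTime.parse("2017-12-22T09:20:44.294658+00:00"));
    rec.set(2, "junction");

    System.out.println(rec);
  }
}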
parse("2017-12-22T06:02:09.243857+00:00"), "forrest"), + record(schema, 3L, parse("2017-12-22T03:10:11.134509+00:00"), "clapping"), + record(schema, 4L, parse("2017-12-22T00:34:00.184671+00:00"), "brush"), + record(schema, 5L, parse("2017-12-21T22:20:08.935889+00:00"), "trap"), + record(schema, 6L, parse("2017-12-21T21:55:30.589712+00:00"), "element"), + record(schema, 7L, parse("2017-12-21T17:31:14.532797+00:00"), "limited"), + record(schema, 8L, parse("2017-12-21T15:21:51.237521+00:00"), "global"), + record(schema, 9L, parse("2017-12-21T15:02:15.230570+00:00"), "goldfish") ); } - private static List read(String table, String expr) { - return read(table, expr, "*"); + private static List read(String table, boolean vectorized, String expr) { + return read(table, vectorized, expr, "*"); } - private static List read(String table, String expr, String select0, String... selectN) { - Dataset dataset = spark.read().format("iceberg").load(table).filter(expr) + private static List read(String table, boolean vectorized, String expr, String select0, String... selectN) { + Dataset dataset = spark.read().format("iceberg") + .option("vectorization-enabled", String.valueOf(vectorized)) + .load(table).filter(expr) .select(select0, selectN); return dataset.collectAsList(); } - private static long timestamp(String timestamp) { - return Literal.of(timestamp).to(Types.TimestampType.withZone()).value(); + private static OffsetDateTime parse(String timestamp) { + return OffsetDateTime.parse(timestamp); } - private static Record record(org.apache.avro.Schema schema, Object... values) { - Record rec = new Record(schema); + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); for (int i = 0; i < values.length; i += 1) { - rec.put(i, values[i]); + rec.set(i, values[i]); } return rec; } diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java b/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java index 83f3f322972c..3b90f61e8d64 100644 --- a/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java +++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java @@ -20,7 +20,7 @@ package org.apache.iceberg.spark.source; public class TestIdentityPartitionData3 extends TestIdentityPartitionData { - public TestIdentityPartitionData3(String format) { - super(format); + public TestIdentityPartitionData3(String format, boolean vectorized) { + super(format, vectorized); } } diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java b/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java index 63db54ebf653..9b421921d649 100644 --- a/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java +++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java @@ -20,7 +20,7 @@ package org.apache.iceberg.spark.source; public class TestPartitionValues3 extends TestPartitionValues { - public TestPartitionValues3(String format) { - super(format); + public TestPartitionValues3(String format, boolean vectorized) { + super(format, vectorized); } }