diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriter.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriter.java
new file mode 100644
index 000000000000..ef0ccf28e95f
--- /dev/null
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriter.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data;
+
+import org.apache.orc.storage.ql.exec.vector.ColumnVector;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
+
+interface SparkOrcValueWriter {
+
+  /**
+   * Take a value from the data and add it to the ORC output.
+   *
+   * @param rowId the row id in the ColumnVector.
+   * @param column the column number.
+   * @param data the data value to write.
+   * @param output the ColumnVector to put the value into.
+   */
+  default void write(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+    if (data.isNullAt(column)) {
+      output.noNulls = false;
+      output.isNull[rowId] = true;
+    } else {
+      output.isNull[rowId] = false;
+      nonNullWrite(rowId, column, data, output);
+    }
+  }
+
+  void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output);
+}
diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
new file mode 100644
index 000000000000..8bb0f53f83cb
--- /dev/null
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data;
+
+import org.apache.orc.storage.common.type.HiveDecimal;
+import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
+import org.apache.orc.storage.ql.exec.vector.ColumnVector;
+import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector;
+import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector;
+import org.apache.orc.storage.ql.exec.vector.ListColumnVector;
+import org.apache.orc.storage.ql.exec.vector.LongColumnVector;
+import org.apache.orc.storage.ql.exec.vector.MapColumnVector;
+import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+
+class SparkOrcValueWriters {
+  private SparkOrcValueWriters() {
+  }
+
+  static SparkOrcValueWriter booleans() {
+    return BooleanWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter bytes() {
+    return ByteWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter shorts() {
+    return ShortWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter ints() {
+    return IntWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter longs() {
+    return LongWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter floats() {
+    return FloatWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter doubles() {
+    return DoubleWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter byteArrays() {
+    return BytesWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter strings() {
+    return StringWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter timestampTz() {
+    return TimestampTzWriter.INSTANCE;
+  }
+
+  static SparkOrcValueWriter decimal(int precision, int scale) {
+    if (precision <= 18) {
+      return new Decimal18Writer(precision, scale);
+    } else {
+      return new Decimal38Writer(precision, scale);
+    }
+  }
+
+  static SparkOrcValueWriter list(SparkOrcValueWriter element) {
+    return new ListWriter(element);
+  }
+
+  static SparkOrcValueWriter map(SparkOrcValueWriter keyWriter, SparkOrcValueWriter valueWriter) {
+    return new MapWriter(keyWriter, valueWriter);
+  }
+
+  private static class BooleanWriter implements SparkOrcValueWriter {
+    private static final BooleanWriter INSTANCE = new BooleanWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((LongColumnVector) output).vector[rowId] = data.getBoolean(column) ? 1 : 0;
+    }
+  }
+
+  private static class ByteWriter implements SparkOrcValueWriter {
+    private static final ByteWriter INSTANCE = new ByteWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((LongColumnVector) output).vector[rowId] = data.getByte(column);
+    }
+  }
+
+  private static class ShortWriter implements SparkOrcValueWriter {
+    private static final ShortWriter INSTANCE = new ShortWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((LongColumnVector) output).vector[rowId] = data.getShort(column);
+    }
+  }
+
+  private static class IntWriter implements SparkOrcValueWriter {
+    private static final IntWriter INSTANCE = new IntWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((LongColumnVector) output).vector[rowId] = data.getInt(column);
+    }
+  }
+
+  private static class LongWriter implements SparkOrcValueWriter {
+    private static final LongWriter INSTANCE = new LongWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((LongColumnVector) output).vector[rowId] = data.getLong(column);
+    }
+  }
+
+  private static class FloatWriter implements SparkOrcValueWriter {
+    private static final FloatWriter INSTANCE = new FloatWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((DoubleColumnVector) output).vector[rowId] = data.getFloat(column);
+    }
+  }
+
+  private static class DoubleWriter implements SparkOrcValueWriter {
+    private static final DoubleWriter INSTANCE = new DoubleWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((DoubleColumnVector) output).vector[rowId] = data.getDouble(column);
+    }
+  }
+
+  private static class BytesWriter implements SparkOrcValueWriter {
+    private static final BytesWriter INSTANCE = new BytesWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      // getBinary always makes a copy, so we don't need to worry about it
+      // being changed behind our back.
+      byte[] value = data.getBinary(column);
+      ((BytesColumnVector) output).setRef(rowId, value, 0, value.length);
+    }
+  }
+
+  private static class StringWriter implements SparkOrcValueWriter {
+    private static final StringWriter INSTANCE = new StringWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      byte[] value = data.getUTF8String(column).getBytes();
+      ((BytesColumnVector) output).setRef(rowId, value, 0, value.length);
+    }
+  }
+
+  private static class TimestampTzWriter implements SparkOrcValueWriter {
+    private static final TimestampTzWriter INSTANCE = new TimestampTzWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      TimestampColumnVector cv = (TimestampColumnVector) output;
+      long micros = data.getLong(column);
+      cv.time[rowId] = micros / 1_000; // millis
+      cv.nanos[rowId] = (int) (micros % 1_000_000) * 1_000; // nanos
+    }
+  }
+
+  private static class Decimal18Writer implements SparkOrcValueWriter {
+    private final int precision;
+    private final int scale;
+
+    Decimal18Writer(int precision, int scale) {
+      this.precision = precision;
+      this.scale = scale;
+    }
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((DecimalColumnVector) output).vector[rowId].setFromLongAndScale(
+          data.getDecimal(column, precision, scale).toUnscaledLong(), scale);
+    }
+  }
+
+  private static class Decimal38Writer implements SparkOrcValueWriter {
+    private final int precision;
+    private final int scale;
+
+    Decimal38Writer(int precision, int scale) {
+      this.precision = precision;
+      this.scale = scale;
+    }
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ((DecimalColumnVector) output).vector[rowId].set(
+          HiveDecimal.create(data.getDecimal(column, precision, scale)
+              .toJavaBigDecimal()));
+    }
+  }
+
+  private static class ListWriter implements SparkOrcValueWriter {
+    private final SparkOrcValueWriter writer;
+
+    ListWriter(SparkOrcValueWriter writer) {
+      this.writer = writer;
+    }
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      ArrayData value = data.getArray(column);
+      ListColumnVector cv = (ListColumnVector) output;
+      // record the length and start of the list elements
+      cv.lengths[rowId] = value.numElements();
+      cv.offsets[rowId] = cv.childCount;
+      cv.childCount += cv.lengths[rowId];
+      // make sure the child is big enough
+      cv.child.ensureSize(cv.childCount, true);
+      // Add each element
+      for (int e = 0; e < cv.lengths[rowId]; ++e) {
+        writer.write((int) (e + cv.offsets[rowId]), e, value, cv.child);
+      }
+    }
+  }
+
+  private static class MapWriter implements SparkOrcValueWriter {
+    private final SparkOrcValueWriter keyWriter;
+    private final SparkOrcValueWriter valueWriter;
+
+    MapWriter(SparkOrcValueWriter keyWriter, SparkOrcValueWriter valueWriter) {
+      this.keyWriter = keyWriter;
+      this.valueWriter = valueWriter;
+    }
+
+    @Override
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      MapData map = data.getMap(column);
+      ArrayData key = map.keyArray();
+      ArrayData value = map.valueArray();
+      MapColumnVector cv = (MapColumnVector) output;
+      // record the length and start of the list elements
+      cv.lengths[rowId] = value.numElements();
+      cv.offsets[rowId] = cv.childCount;
+      cv.childCount += cv.lengths[rowId];
+      // make sure the child is big enough
+      cv.keys.ensureSize(cv.childCount, true);
+      cv.values.ensureSize(cv.childCount, true);
+      // Add each element
+      for (int e = 0; e < cv.lengths[rowId]; ++e) {
+        int pos = (int) (e + cv.offsets[rowId]);
+        keyWriter.write(pos, e, key, cv.keys);
+        valueWriter.write(pos, e, value, cv.values);
+      }
+    }
+  }
+}
diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java
index 0361fdc1c0c8..4508a102d447 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java
@@ -20,23 +20,18 @@
 package org.apache.iceberg.spark.data;
 
 import java.util.List;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.orc.OrcRowWriter;
+import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
 import org.apache.orc.TypeDescription;
-import org.apache.orc.storage.common.type.HiveDecimal;
-import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
 import org.apache.orc.storage.ql.exec.vector.ColumnVector;
-import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector;
-import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector;
-import org.apache.orc.storage.ql.exec.vector.ListColumnVector;
-import org.apache.orc.storage.ql.exec.vector.LongColumnVector;
-import org.apache.orc.storage.ql.exec.vector.MapColumnVector;
 import org.apache.orc.storage.ql.exec.vector.StructColumnVector;
-import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector;
 import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
-import org.apache.spark.sql.catalyst.util.ArrayData;
-import org.apache.spark.sql.catalyst.util.MapData;
 
 /**
  * This class acts as an adaptor from an OrcFileAppender to a
@@ -44,375 +39,102 @@
  */
 public class SparkOrcWriter implements OrcRowWriter<InternalRow> {
-  private final Converter[] converters;
+  private final SparkOrcValueWriter writer;
 
-  public SparkOrcWriter(TypeDescription schema) {
-    converters = buildConverters(schema);
+  public SparkOrcWriter(Schema iSchema, TypeDescription orcSchema) {
+    Preconditions.checkArgument(orcSchema.getCategory() == TypeDescription.Category.STRUCT,
+        "Top level must be a struct " + orcSchema);
+
+    writer = OrcSchemaWithTypeVisitor.visit(iSchema, orcSchema, new WriteBuilder());
   }
 
   @Override
   public void write(InternalRow value, VectorizedRowBatch output) {
-    int row = output.size++;
-    for (int c = 0; c < converters.length; ++c) {
-      converters[c].addValue(row, c, value, output.cols[c]);
-    }
-  }
-
-  /**
-   * The interface for the conversion from Spark's SpecializedGetters to
-   * ORC's ColumnVectors.
-   */
-  interface Converter {
-    /**
-     * Take a value from the Spark data value and add it to the ORC output.
-     * @param rowId the row in the ColumnVector
-     * @param column either the column number or element number
-     * @param data either an InternalRow or ArrayData
-     * @param output the ColumnVector to put the value into
-     */
-    void addValue(int rowId, int column, SpecializedGetters data,
-                  ColumnVector output);
-  }
-
-  static class BooleanConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((LongColumnVector) output).vector[rowId] = data.getBoolean(column) ? 1 : 0;
-      }
-    }
-  }
-
-  static class ByteConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((LongColumnVector) output).vector[rowId] = data.getByte(column);
-      }
-    }
-  }
-
-  static class ShortConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((LongColumnVector) output).vector[rowId] = data.getShort(column);
-      }
-    }
-  }
-
-  static class IntConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((LongColumnVector) output).vector[rowId] = data.getInt(column);
-      }
-    }
-  }
-
-  static class LongConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((LongColumnVector) output).vector[rowId] = data.getLong(column);
-      }
-    }
-  }
-
-  static class FloatConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((DoubleColumnVector) output).vector[rowId] = data.getFloat(column);
-      }
-    }
-  }
-
-  static class DoubleConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((DoubleColumnVector) output).vector[rowId] = data.getDouble(column);
-      }
-    }
-  }
-
-  static class StringConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        byte[] value = data.getUTF8String(column).getBytes();
-        ((BytesColumnVector) output).setRef(rowId, value, 0, value.length);
-      }
-    }
-  }
+    Preconditions.checkArgument(writer instanceof StructWriter, "writer must be StructWriter");
+
-
-  static class BytesConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        // getBinary always makes a copy, so we don't need to worry about it
-        // being changed behind our back.
-        byte[] value = data.getBinary(column);
-        ((BytesColumnVector) output).setRef(rowId, value, 0, value.length);
-      }
-    }
-  }
+    int row = output.size;
+    output.size += 1;
+    List<SparkOrcValueWriter> writers = ((StructWriter) writer).writers();
+    for (int c = 0; c < writers.size(); c++) {
+      SparkOrcValueWriter child = writers.get(c);
+      child.write(row, c, value, output.cols[c]);
+    }
   }
 
-  static class TimestampTzConverter implements Converter {
-    @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        TimestampColumnVector cv = (TimestampColumnVector) output;
-        long micros = data.getLong(column);
-        cv.time[rowId] = micros / 1_000; // millis
-        cv.nanos[rowId] = (int) (micros % 1_000_000) * 1_000; // nanos
-      }
-    }
-  }
-
-  static class Decimal18Converter implements Converter {
-    private final int precision;
-    private final int scale;
-
-    Decimal18Converter(TypeDescription schema) {
-      precision = schema.getPrecision();
-      scale = schema.getScale();
+  private static class WriteBuilder extends OrcSchemaWithTypeVisitor<SparkOrcValueWriter> {
+    private WriteBuilder() {
     }
 
     @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((DecimalColumnVector) output).vector[rowId].setFromLongAndScale(
-            data.getDecimal(column, precision, scale).toUnscaledLong(), scale);
-      }
-    }
-  }
-
-  static class Decimal38Converter implements Converter {
-    private final int precision;
-    private final int scale;
-
-    Decimal38Converter(TypeDescription schema) {
-      precision = schema.getPrecision();
-      scale = schema.getScale();
+    public SparkOrcValueWriter record(Types.StructType iStruct, TypeDescription record,
+                                      List<String> names, List<SparkOrcValueWriter> fields) {
+      return new StructWriter(fields);
     }
 
     @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ((DecimalColumnVector) output).vector[rowId].set(
-            HiveDecimal.create(data.getDecimal(column, precision, scale)
-                .toJavaBigDecimal()));
-      }
-    }
-  }
-
-  static class StructConverter implements Converter {
-    private final Converter[] children;
-
-    StructConverter(TypeDescription schema) {
-      children = new Converter[schema.getChildren().size()];
-      for (int c = 0; c < children.length; ++c) {
-        children[c] = buildConverter(schema.getChildren().get(c));
-      }
+    public SparkOrcValueWriter list(Types.ListType iList, TypeDescription array,
+                                    SparkOrcValueWriter element) {
+      return SparkOrcValueWriters.list(element);
     }
 
     @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        InternalRow value = data.getStruct(column, children.length);
-        StructColumnVector cv = (StructColumnVector) output;
-        for (int c = 0; c < children.length; ++c) {
-          children[c].addValue(rowId, c, value, cv.fields[c]);
-        }
-      }
-    }
-  }
-
-  static class ListConverter implements Converter {
-    private final Converter children;
-
-    ListConverter(TypeDescription schema) {
-      children = buildConverter(schema.getChildren().get(0));
+    public SparkOrcValueWriter map(Types.MapType iMap, TypeDescription map,
+                                   SparkOrcValueWriter key, SparkOrcValueWriter value) {
+      return SparkOrcValueWriters.map(key, value);
    }
 
     @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        ArrayData value = data.getArray(column);
-        ListColumnVector cv = (ListColumnVector) output;
-        // record the length and start of the list elements
-        cv.lengths[rowId] = value.numElements();
-        cv.offsets[rowId] = cv.childCount;
-        cv.childCount += cv.lengths[rowId];
-        // make sure the child is big enough
-        cv.child.ensureSize(cv.childCount, true);
-        // Add each element
-        for (int e = 0; e < cv.lengths[rowId]; ++e) {
-          children.addValue((int) (e + cv.offsets[rowId]), e, value, cv.child);
-        }
-      }
-    }
-  }
-
-  static class MapConverter implements Converter {
-    private final Converter keyConverter;
-    private final Converter valueConverter;
-
-    MapConverter(TypeDescription schema) {
-      keyConverter = buildConverter(schema.getChildren().get(0));
-      valueConverter = buildConverter(schema.getChildren().get(1));
+    public SparkOrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) {
+      switch (primitive.getCategory()) {
+        case BOOLEAN:
+          return SparkOrcValueWriters.booleans();
+        case BYTE:
+          return SparkOrcValueWriters.bytes();
+        case SHORT:
+          return SparkOrcValueWriters.shorts();
+        case DATE:
+        case INT:
+          return SparkOrcValueWriters.ints();
+        case LONG:
+          return SparkOrcValueWriters.longs();
+        case FLOAT:
+          return SparkOrcValueWriters.floats();
+        case DOUBLE:
+          return SparkOrcValueWriters.doubles();
+        case BINARY:
+          return SparkOrcValueWriters.byteArrays();
+        case STRING:
+        case CHAR:
+        case VARCHAR:
+          return SparkOrcValueWriters.strings();
+        case DECIMAL:
+          return SparkOrcValueWriters.decimal(primitive.getPrecision(), primitive.getScale());
+        case TIMESTAMP_INSTANT:
+          return SparkOrcValueWriters.timestampTz();
+        default:
+          throw new IllegalArgumentException("Unhandled type " + primitive);
+      }
+    }
+  }
+
+  private static class StructWriter implements SparkOrcValueWriter {
+    private final List<SparkOrcValueWriter> writers;
+
+    StructWriter(List<SparkOrcValueWriter> writers) {
+      this.writers = writers;
+    }
+
+    List<SparkOrcValueWriter> writers() {
+      return writers;
     }
 
     @Override
-    public void addValue(int rowId, int column, SpecializedGetters data,
-                         ColumnVector output) {
-      if (data.isNullAt(column)) {
-        output.noNulls = false;
-        output.isNull[rowId] = true;
-      } else {
-        output.isNull[rowId] = false;
-        MapData map = data.getMap(column);
-        ArrayData key = map.keyArray();
-        ArrayData value = map.valueArray();
-        MapColumnVector cv = (MapColumnVector) output;
-        // record the length and start of the list elements
-        cv.lengths[rowId] = value.numElements();
-        cv.offsets[rowId] = cv.childCount;
-        cv.childCount += cv.lengths[rowId];
-        // make sure the child is big enough
-        cv.keys.ensureSize(cv.childCount, true);
-        cv.values.ensureSize(cv.childCount, true);
-        // Add each element
-        for (int e = 0; e < cv.lengths[rowId]; ++e) {
-          int pos = (int) (e + cv.offsets[rowId]);
-          keyConverter.addValue(pos, e, key, cv.keys);
-          valueConverter.addValue(pos, e, value, cv.values);
-        }
+    public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
+      InternalRow value = data.getStruct(column, writers.size());
+      StructColumnVector cv = (StructColumnVector) output;
+      for (int c = 0; c < writers.size(); ++c) {
+        writers.get(c).write(rowId, c, value, cv.fields[c]);
      }
     }
   }
-
-  private static Converter buildConverter(TypeDescription schema) {
-    switch (schema.getCategory()) {
-      case BOOLEAN:
-        return new BooleanConverter();
-      case BYTE:
-        return new ByteConverter();
-      case SHORT:
-        return new ShortConverter();
-      case DATE:
-      case INT:
-        return new IntConverter();
-      case LONG:
-        return new LongConverter();
-      case FLOAT:
-        return new FloatConverter();
-      case DOUBLE:
-        return new DoubleConverter();
-      case BINARY:
-        return new BytesConverter();
-      case STRING:
-      case CHAR:
-      case VARCHAR:
-        return new StringConverter();
-      case DECIMAL:
-        return schema.getPrecision() <= 18 ?
-            new Decimal18Converter(schema) :
-            new Decimal38Converter(schema);
-      case TIMESTAMP_INSTANT:
-        return new TimestampTzConverter();
-      case STRUCT:
-        return new StructConverter(schema);
-      case LIST:
-        return new ListConverter(schema);
-      case MAP:
-        return new MapConverter(schema);
-    }
-    throw new IllegalArgumentException("Unhandled type " + schema);
-  }
-
-  private static Converter[] buildConverters(TypeDescription schema) {
-    if (schema.getCategory() != TypeDescription.Category.STRUCT) {
-      throw new IllegalArgumentException("Top level must be a struct " + schema);
-    }
-    List<TypeDescription> children = schema.getChildren();
-    Converter[] result = new Converter[children.size()];
-    for (int c = 0; c < children.size(); ++c) {
-      result[c] = buildConverter(children.get(c));
-    }
-    return result;
-  }
-
 }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java b/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java
index d7b271e82396..e29c6eb319dc 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java
@@ -70,7 +70,7 @@ public FileAppender newAppender(OutputFile file, FileFormat fileFor
       case ORC:
         return ORC.write(file)
-            .createWriterFunc((schema, typeDesc) -> new SparkOrcWriter(typeDesc))
+            .createWriterFunc(SparkOrcWriter::new)
             .setAll(properties)
             .schema(writeSchema)
             .overwrite()
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java
index 2c514521da80..7cf9b9c736c6 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java
@@ -50,7 +50,7 @@ public void splitOffsets() throws IOException {
     Iterable<InternalRow> rows = RandomData.generateSpark(SCHEMA, 1, 0L);
     FileAppender<InternalRow> writer = ORC.write(Files.localOutput(testFile))
-        .createWriterFunc((schema, typeDesc) -> new SparkOrcWriter(typeDesc))
+        .createWriterFunc(SparkOrcWriter::new)
         .schema(SCHEMA)
         .build();
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java
index fdb378335890..03ea3c443df9 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java
@@ -119,7 +119,7 @@ public void writeFile() throws IOException {
succeed", testFile.delete()); try (FileAppender writer = ORC.write(Files.localOutput(testFile)) - .createWriterFunc((icebergSchema, typeDesc) -> new SparkOrcWriter(typeDesc)) + .createWriterFunc(SparkOrcWriter::new) .schema(DATA_SCHEMA) // write in such a way that the file contains 10 stripes each with 100 rows .config("iceberg.orc.vectorbatch.size", "100") diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java index 5822c2ebe347..5042d1cc1338 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java @@ -67,7 +67,7 @@ private void writeAndValidateRecords(Schema schema, Iterable expect Assert.assertTrue("Delete should succeed", testFile.delete()); try (FileAppender writer = ORC.write(Files.localOutput(testFile)) - .createWriterFunc((icebergSchema, typeDesc) -> new SparkOrcWriter(typeDesc)) + .createWriterFunc(SparkOrcWriter::new) .schema(schema) .build()) { writer.addAll(expected);