diff --git a/api/src/main/java/org/apache/iceberg/PartitionKey.java b/api/src/main/java/org/apache/iceberg/PartitionKey.java
new file mode 100644
index 000000000000..71cdb2756ed2
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/PartitionKey.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg;
+
+import java.io.Serializable;
+import java.lang.reflect.Array;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.transforms.Transform;
+
+/**
+ * A struct of partition values.
+ * <p>
+ * Instances of this class can produce partition values from a data row passed to {@link #partition(StructLike)}.
+ */
+public class PartitionKey implements StructLike, Serializable {
+
+  private final PartitionSpec spec;
+  private final int size;
+  private final Object[] partitionTuple;
+  private final Transform[] transforms;
+  private final Accessor<StructLike>[] accessors;
+
+  @SuppressWarnings("unchecked")
+  public PartitionKey(PartitionSpec spec, Schema inputSchema) {
+    this.spec = spec;
+
+    List<PartitionField> fields = spec.fields();
+    this.size = fields.size();
+    this.partitionTuple = new Object[size];
+    this.transforms = new Transform[size];
+    this.accessors = (Accessor<StructLike>[]) Array.newInstance(Accessor.class, size);
+
+    Schema schema = spec.schema();
+    for (int i = 0; i < size; i += 1) {
+      PartitionField field = fields.get(i);
+      Accessor<StructLike> accessor = inputSchema.accessorForField(field.sourceId());
+      Preconditions.checkArgument(accessor != null,
+          "Cannot build accessor for field: " + schema.findField(field.sourceId()));
+      this.accessors[i] = accessor;
+      this.transforms[i] = field.transform();
+    }
+  }
+
+  private PartitionKey(PartitionKey toCopy) {
+    this.spec = toCopy.spec;
+    this.size = toCopy.size;
+    this.partitionTuple = new Object[toCopy.partitionTuple.length];
+    this.transforms = toCopy.transforms;
+    this.accessors = toCopy.accessors;
+
+    System.arraycopy(toCopy.partitionTuple, 0, this.partitionTuple, 0, partitionTuple.length);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    sb.append("[");
+    for (int i = 0; i < partitionTuple.length; i += 1) {
+      if (i > 0) {
+        sb.append(", ");
+      }
+      sb.append(partitionTuple[i]);
+    }
+    sb.append("]");
+    return sb.toString();
+  }
+
+  public PartitionKey copy() {
+    return new PartitionKey(this);
+  }
+
+  public String toPath() {
+    return spec.partitionToPath(this);
+  }
+
+  /**
+   * Replace this key's partition values with the partition values for the row.
+   *
+   * @param row a StructLike row
+   */
+  @SuppressWarnings("unchecked")
+  public void partition(StructLike row) {
+    for (int i = 0; i < partitionTuple.length; i += 1) {
+      Transform<Object, Object> transform = transforms[i];
+      partitionTuple[i] = transform.apply(accessors[i].get(row));
+    }
+  }
+
+  @Override
+  public int size() {
+    return size;
+  }
+
+  @Override
+  public <T> T get(int pos, Class<T> javaClass) {
+    return javaClass.cast(partitionTuple[pos]);
+  }
+
+  @Override
+  public <T> void set(int pos, T value) {
+    partitionTuple[pos] = value;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    } else if (!(o instanceof PartitionKey)) {
+      return false;
+    }
+
+    PartitionKey that = (PartitionKey) o;
+    return Arrays.equals(partitionTuple, that.partitionTuple);
+  }
+
+  @Override
+  public int hashCode() {
+    return Arrays.hashCode(partitionTuple);
+  }
+}
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java b/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java
index f3a10301de36..8c41e77d0f10 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java
@@ -26,6 +26,7 @@
 import org.apache.iceberg.DataFiles;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.Metrics;
+import org.apache.iceberg.PartitionKey;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.encryption.EncryptedOutputFile;
 import org.apache.iceberg.io.FileAppender;
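The new public PartitionKey works against any StructLike row, not just Spark rows. A minimal usage sketch under assumed names (the schema, values, and GenericRecord from the separate iceberg-data module are illustrative, not part of this patch):

    import org.apache.iceberg.PartitionKey;
    import org.apache.iceberg.PartitionSpec;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.data.GenericRecord;
    import org.apache.iceberg.types.Types;

    public class PartitionKeyExample {
      public static void main(String[] args) {
        Schema schema = new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.required(2, "category", Types.StringType.get()));

        PartitionSpec spec = PartitionSpec.builderFor(schema)
            .identity("category")
            .build();

        // GenericRecord implements StructLike; any StructLike row works here
        GenericRecord row = GenericRecord.create(schema);
        row.setField("id", 1L);
        row.setField("category", "books");

        PartitionKey key = new PartitionKey(spec, schema);
        key.partition(row);                // fills the key in place from the row
        System.out.println(key.toPath());  // category=books

        // partition() mutates this instance, so copy() before caching the key
        PartitionKey saved = key.copy();
      }
    }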
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java b/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java
new file mode 100644
index 000000000000..ef1eb08d873c
--- /dev/null
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.source;
+
+import java.nio.ByteBuffer;
+import java.util.function.BiFunction;
+import java.util.stream.Stream;
+import org.apache.iceberg.StructLike;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Class to adapt a Spark {@code InternalRow} to Iceberg {@link StructLike} for uses like
+ * {@link org.apache.iceberg.PartitionKey#partition(StructLike)}
+ */
+class InternalRowWrapper implements StructLike {
+  private final DataType[] types;
+  private final BiFunction<InternalRow, Integer, ?>[] getters;
+  private InternalRow row = null;
+
+  @SuppressWarnings("unchecked")
+  InternalRowWrapper(StructType rowType) {
+    this.types = Stream.of(rowType.fields())
+        .map(StructField::dataType)
+        .toArray(DataType[]::new);
+    this.getters = Stream.of(types)
+        .map(InternalRowWrapper::getter)
+        .toArray(BiFunction[]::new);
+  }
+
+  InternalRowWrapper wrap(InternalRow internalRow) {
+    this.row = internalRow;
+    return this;
+  }
+
+  @Override
+  public int size() {
+    return types.length;
+  }
+
+  @Override
+  public <T> T get(int pos, Class<T> javaClass) {
+    if (row.isNullAt(pos)) {
+      return null;
+    } else if (getters[pos] != null) {
+      return javaClass.cast(getters[pos].apply(row, pos));
+    }
+
+    return javaClass.cast(row.get(pos, types[pos]));
+  }
+
+  @Override
+  public <T> void set(int pos, T value) {
+    row.update(pos, value);
+  }
+
+  private static BiFunction<InternalRow, Integer, ?> getter(DataType type) {
+    if (type instanceof StringType) {
+      return (row, pos) -> row.getUTF8String(pos).toString();
+    } else if (type instanceof DecimalType) {
+      DecimalType decimal = (DecimalType) type;
+      return (row, pos) ->
+          row.getDecimal(pos, decimal.precision(), decimal.scale()).toJavaBigDecimal();
+    } else if (type instanceof BinaryType) {
+      return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos));
+    } else if (type instanceof StructType) {
+      StructType structType = (StructType) type;
+      InternalRowWrapper nestedWrapper = new InternalRowWrapper(structType);
+      return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size()));
+    }
+
+    return null;
+  }
+}
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java b/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java
index 1daf889e0512..08e66df79362 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java
@@ -22,6 +22,7 @@
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.PartitionKey;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.encryption.EncryptedOutputFile;
 import org.apache.iceberg.encryption.EncryptionManager;
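InternalRowWrapper above converts Spark's internal representations (UTF8String, Decimal, byte[], nested InternalRow) to the Java types Iceberg transforms expect, falling back to row.get for everything else. A sketch of how it feeds PartitionKey; this assumes caller code in the same org.apache.iceberg.spark.source package, since the wrapper is package-private, and the schema and row values are invented:

    import org.apache.iceberg.PartitionKey;
    import org.apache.iceberg.PartitionSpec;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.spark.SparkSchemaUtil;
    import org.apache.iceberg.types.Types;
    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
    import org.apache.spark.unsafe.types.UTF8String;

    public class WrapperExample {
      public static void main(String[] args) {
        Schema schema = new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.required(2, "category", Types.StringType.get()));
        PartitionSpec spec = PartitionSpec.builderFor(schema).identity("category").build();

        // Spark rows hold UTF8String, not String; the wrapper converts on access
        InternalRow row = new GenericInternalRow(
            new Object[] {1L, UTF8String.fromString("books")});

        InternalRowWrapper wrapper = new InternalRowWrapper(SparkSchemaUtil.convert(schema));
        PartitionKey key = new PartitionKey(spec, schema);
        key.partition(wrapper.wrap(row));  // wrap() mutates and returns the same wrapper
        System.out.println(key.toPath());  // category=books
      }
    }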
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionKey.java b/spark/src/main/java/org/apache/iceberg/spark/source/PartitionKey.java
deleted file mode 100644
index 292be8a9e3b5..000000000000
--- a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionKey.java
+++ /dev/null
@@ -1,364 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iceberg.spark.source;
-
-import java.lang.reflect.Array;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import org.apache.iceberg.PartitionField;
-import org.apache.iceberg.PartitionSpec;
-import org.apache.iceberg.Schema;
-import org.apache.iceberg.StructLike;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
-import org.apache.iceberg.spark.SparkSchemaUtil;
-import org.apache.iceberg.transforms.Transform;
-import org.apache.iceberg.types.Type;
-import org.apache.iceberg.types.TypeUtil;
-import org.apache.iceberg.types.Types;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.types.UTF8String;
-
-class PartitionKey implements StructLike {
-
-  private final PartitionSpec spec;
-  private final int size;
-  private final Object[] partitionTuple;
-  private final Transform[] transforms;
-  private final Accessor<InternalRow>[] accessors;
-
-  @SuppressWarnings("unchecked")
-  PartitionKey(PartitionSpec spec, Schema inputSchema) {
-    this.spec = spec;
-
-    List<PartitionField> fields = spec.fields();
-    this.size = fields.size();
-    this.partitionTuple = new Object[size];
-    this.transforms = new Transform[size];
-    this.accessors = (Accessor<InternalRow>[]) Array.newInstance(Accessor.class, size);
-
-    Schema schema = spec.schema();
-    Map<Integer, Accessor<InternalRow>> newAccessors = buildAccessors(inputSchema);
-    for (int i = 0; i < size; i += 1) {
-      PartitionField field = fields.get(i);
-      Accessor<InternalRow> accessor = newAccessors.get(field.sourceId());
-      if (accessor == null) {
-        throw new RuntimeException(
-            "Cannot build accessor for field: " + schema.findField(field.sourceId()));
-      }
-      this.accessors[i] = accessor;
-      this.transforms[i] = field.transform();
-    }
-  }
-
-  private PartitionKey(PartitionKey toCopy) {
-    this.spec = toCopy.spec;
-    this.size = toCopy.size;
-    this.partitionTuple = new Object[toCopy.partitionTuple.length];
-    this.transforms = toCopy.transforms;
-    this.accessors = toCopy.accessors;
-
-    for (int i = 0; i < partitionTuple.length; i += 1) {
-      this.partitionTuple[i] = defensiveCopyIfNeeded(toCopy.partitionTuple[i]);
-    }
-  }
-
-  private Object defensiveCopyIfNeeded(Object obj) {
-    if (obj instanceof UTF8String) {
-      // bytes backing the UTF8 string might be reused
-      byte[] bytes = ((UTF8String) obj).getBytes();
-      return UTF8String.fromBytes(Arrays.copyOf(bytes, bytes.length));
-    }
-    return obj;
-  }
-
-  @Override
-  public String toString() {
-    StringBuilder sb = new StringBuilder();
-    sb.append("[");
-    for (int i = 0; i < partitionTuple.length; i += 1) {
-      if (i > 0) {
-        sb.append(", ");
-      }
-      sb.append(partitionTuple[i]);
-    }
-    sb.append("]");
-    return sb.toString();
-  }
-
-  PartitionKey copy() {
-    return new PartitionKey(this);
-  }
-
-  String toPath() {
-    return spec.partitionToPath(this);
-  }
-
-  @SuppressWarnings("unchecked")
-  void partition(InternalRow row) {
-    for (int i = 0; i < partitionTuple.length; i += 1) {
-      Transform<Object, Object> transform = transforms[i];
-      partitionTuple[i] = transform.apply(accessors[i].get(row));
-    }
-  }
-
-  @Override
-  public int size() {
-    return size;
-  }
-
-  @Override
-  @SuppressWarnings("unchecked")
-  public <T> T get(int pos, Class<T> javaClass) {
-    return javaClass.cast(partitionTuple[pos]);
-  }
-
-  @Override
-  public <T> void set(int pos, T value) {
-    partitionTuple[pos] = value;
-  }
-
-  @Override
-  public boolean equals(Object o) {
-    if (this == o) {
-      return true;
-    } else if (!(o instanceof PartitionKey)) {
-      return false;
-    }
-
-    PartitionKey that = (PartitionKey) o;
-    return Arrays.equals(partitionTuple, that.partitionTuple);
-  }
-
-  @Override
-  public int hashCode() {
-    return Arrays.hashCode(partitionTuple);
-  }
-
-  private interface Accessor<T> {
-    Object get(T container);
-  }
-
-  private static Map<Integer, Accessor<InternalRow>> buildAccessors(Schema schema) {
-    return TypeUtil.visit(schema, new BuildPositionAccessors());
-  }
-
-  private static Accessor<InternalRow> newAccessor(int position, Type type) {
-    switch (type.typeId()) {
-      case STRING:
-        return new StringAccessor(position, SparkSchemaUtil.convert(type));
-      case DECIMAL:
-        return new DecimalAccessor(position, SparkSchemaUtil.convert(type));
-      case BINARY:
-        return new BytesAccessor(position, SparkSchemaUtil.convert(type));
-      default:
-        return new PositionAccessor(position, SparkSchemaUtil.convert(type));
-    }
-  }
-
-  private static Accessor<InternalRow> newAccessor(int position, boolean isOptional, Types.StructType type,
-                                                   Accessor<InternalRow> accessor) {
-    int size = type.fields().size();
-    if (isOptional) {
-      // the wrapped position handles null layers
-      return new WrappedPositionAccessor(position, size, accessor);
-    } else if (accessor.getClass() == PositionAccessor.class) {
-      return new Position2Accessor(position, size, (PositionAccessor) accessor);
-    } else if (accessor instanceof Position2Accessor) {
-      return new Position3Accessor(position, size, (Position2Accessor) accessor);
-    } else {
-      return new WrappedPositionAccessor(position, size, accessor);
-    }
-  }
-
-  private static class BuildPositionAccessors
-      extends TypeUtil.SchemaVisitor<Map<Integer, Accessor<InternalRow>>> {
-    @Override
-    public Map<Integer, Accessor<InternalRow>> schema(
-        Schema schema, Map<Integer, Accessor<InternalRow>> structResult) {
-      return structResult;
-    }
-
-    @Override
-    public Map<Integer, Accessor<InternalRow>> struct(
-        Types.StructType struct, List<Map<Integer, Accessor<InternalRow>>> fieldResults) {
-      Map<Integer, Accessor<InternalRow>> accessors = Maps.newHashMap();
-      List<Types.NestedField> fields = struct.fields();
-      for (int i = 0; i < fieldResults.size(); i += 1) {
-        Types.NestedField field = fields.get(i);
-        Map<Integer, Accessor<InternalRow>> result = fieldResults.get(i);
-        if (result != null) {
-          for (Map.Entry<Integer, Accessor<InternalRow>> entry : result.entrySet()) {
-            accessors.put(entry.getKey(), newAccessor(i, field.isOptional(),
-                field.type().asNestedType().asStructType(), entry.getValue()));
-          }
-        } else {
-          accessors.put(field.fieldId(), newAccessor(i, field.type()));
-        }
-      }
-
-      if (accessors.isEmpty()) {
-        return null;
-      }
-
-      return accessors;
-    }
-
-    @Override
-    public Map<Integer, Accessor<InternalRow>> field(
-        Types.NestedField field, Map<Integer, Accessor<InternalRow>> fieldResult) {
-      return fieldResult;
-    }
-  }
-
-  private static class PositionAccessor implements Accessor<InternalRow> {
-    private final DataType type;
-    private int position;
-
-    private PositionAccessor(int position, DataType type) {
-      this.position = position;
-      this.type = type;
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      if (row.isNullAt(position)) {
-        return null;
-      }
-      return row.get(position, type);
-    }
-
-    DataType type() {
-      return type;
-    }
-
-    int position() {
-      return position;
-    }
-  }
-
-  private static class StringAccessor extends PositionAccessor {
-    private StringAccessor(int position, DataType type) {
-      super(position, type);
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      if (row.isNullAt(position())) {
-        return null;
-      }
-      return row.get(position(), type()).toString();
-    }
-  }
-
-  private static class DecimalAccessor extends PositionAccessor {
-    private DecimalAccessor(int position, DataType type) {
-      super(position, type);
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      if (row.isNullAt(position())) {
-        return null;
-      }
-      return ((Decimal) row.get(position(), type())).toJavaBigDecimal();
-    }
-  }
-
-  private static class BytesAccessor extends PositionAccessor {
-    private BytesAccessor(int position, DataType type) {
-      super(position, type);
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      if (row.isNullAt(position())) {
-        return null;
-      }
-      return ByteBuffer.wrap((byte[]) row.get(position(), type()));
-    }
-  }
-
-  private static class Position2Accessor implements Accessor<InternalRow> {
-    private final int p0;
-    private final int size0;
-    private final int p1;
-    private final DataType type;
-
-    private Position2Accessor(int position, int size, PositionAccessor wrapped) {
-      this.p0 = position;
-      this.size0 = size;
-      this.p1 = wrapped.position;
-      this.type = wrapped.type;
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      return row.getStruct(p0, size0).get(p1, type);
-    }
-  }
-
-  private static class Position3Accessor implements Accessor<InternalRow> {
-    private final int p0;
-    private final int size0;
-    private final int p1;
-    private final int size1;
-    private final int p2;
-    private final DataType type;
-
-    private Position3Accessor(int position, int size, Position2Accessor wrapped) {
-      this.p0 = position;
-      this.size0 = size;
-      this.p1 = wrapped.p0;
-      this.size1 = wrapped.size0;
-      this.p2 = wrapped.p1;
-      this.type = wrapped.type;
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      return row.getStruct(p0, size0).getStruct(p1, size1).get(p2, type);
-    }
-  }
-
-  private static class WrappedPositionAccessor implements Accessor<InternalRow> {
-    private final int position;
-    private final int size;
-    private final Accessor<InternalRow> accessor;
-
-    private WrappedPositionAccessor(int position, int size, Accessor<InternalRow> accessor) {
-      this.position = position;
-      this.size = size;
-      this.accessor = accessor;
-    }
-
-    @Override
-    public Object get(InternalRow row) {
-      InternalRow inner = row.getStruct(position, size);
-      if (inner != null) {
-        return accessor.get(inner);
-      }
-      return null;
-    }
-  }
-}
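The visitor-based accessor machinery deleted above is what Schema.accessorForField now supplies at the API level (the new PartitionKey constructor uses it), while the Spark-specific value conversions moved into InternalRowWrapper's getters. A sketch of that accessor API; it assumes the public org.apache.iceberg.Accessor interface that the new PartitionKey compiles against, and an invented schema with a nested field:

    import org.apache.iceberg.Accessor;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.StructLike;
    import org.apache.iceberg.data.GenericRecord;
    import org.apache.iceberg.types.Types;

    public class AccessorExample {
      public static void main(String[] args) {
        Schema schema = new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "location", Types.StructType.of(
                Types.NestedField.required(3, "zip", Types.StringType.get()))));

        GenericRecord location = GenericRecord.create(schema.findType("location").asStructType());
        location.setField("zip", "94105");
        GenericRecord row = GenericRecord.create(schema);
        row.setField("id", 1L);
        row.setField("location", location);

        // accessorForField resolves nested positions, replacing the deleted visitor code
        Accessor<StructLike> zip = schema.accessorForField(3);
        System.out.println(zip.get(row));  // 94105
      }
    }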
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java b/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java
index 7cc0c9f90555..0ead766f38c7 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java
@@ -22,11 +22,13 @@
 import java.io.IOException;
 import java.util.Set;
 import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.PartitionKey;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.io.FileIO;
 import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -35,17 +37,19 @@ class PartitionedWriter extends BaseWriter {
   private static final Logger LOG = LoggerFactory.getLogger(PartitionedWriter.class);
 
   private final PartitionKey key;
+  private final InternalRowWrapper wrapper;
   private final Set<PartitionKey> completedPartitions = Sets.newHashSet();
 
   PartitionedWriter(PartitionSpec spec, FileFormat format, SparkAppenderFactory appenderFactory,
                     OutputFileFactory fileFactory, FileIO io, long targetFileSize, Schema writeSchema) {
     super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
     this.key = new PartitionKey(spec, writeSchema);
+    this.wrapper = new InternalRowWrapper(SparkSchemaUtil.convert(writeSchema));
   }
 
   @Override
   public void write(InternalRow row) throws IOException {
-    key.partition(row);
+    key.partition(wrapper.wrap(row));
 
     PartitionKey currentKey = getCurrentKey();
     if (!key.equals(currentKey)) {
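Because partition(...) refills the same mutable key on every call, PartitionedWriter compares against the current key and relies on copy() to freeze a key before tracking it. An illustrative reduction of that pattern (names invented; this is not the writer's actual code, which the hunk above truncates):

    import java.util.Set;
    import org.apache.iceberg.PartitionKey;
    import org.apache.iceberg.StructLike;
    import org.apache.iceberg.relocated.com.google.common.collect.Sets;

    class KeyTracking {
      private final PartitionKey key;  // reused for every row
      private final Set<PartitionKey> completed = Sets.newHashSet();
      private PartitionKey currentKey = null;

      KeyTracking(PartitionKey key) {
        this.key = key;
      }

      void write(StructLike row) {
        key.partition(row);             // mutates `key` in place
        if (!key.equals(currentKey)) {
          if (currentKey != null) {
            completed.add(currentKey);  // safe: currentKey is a frozen copy
          }
          currentKey = key.copy();      // copy before the next partition() call
        }
        // ... append `row` to the open file for currentKey ...
      }
    }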