diff --git a/api/src/main/java/org/apache/iceberg/PartitionKey.java b/api/src/main/java/org/apache/iceberg/PartitionKey.java
new file mode 100644
index 000000000000..71cdb2756ed2
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/PartitionKey.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg;
+
+import java.io.Serializable;
+import java.lang.reflect.Array;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.transforms.Transform;
+
+/**
+ * A struct of partition values.
+ * <p>
+ * Instances of this class can produce partition values from a data row passed to {@link #partition(StructLike)}.
+ */
+public class PartitionKey implements StructLike, Serializable {
+
+ private final PartitionSpec spec;
+ private final int size;
+ private final Object[] partitionTuple;
+ private final Transform[] transforms;
+ private final Accessor<StructLike>[] accessors;
+
+ @SuppressWarnings("unchecked")
+ public PartitionKey(PartitionSpec spec, Schema inputSchema) {
+ this.spec = spec;
+
+ List<PartitionField> fields = spec.fields();
+ this.size = fields.size();
+ this.partitionTuple = new Object[size];
+ this.transforms = new Transform[size];
+ this.accessors = (Accessor<StructLike>[]) Array.newInstance(Accessor.class, size);
+
+ Schema schema = spec.schema();
+ for (int i = 0; i < size; i += 1) {
+ PartitionField field = fields.get(i);
+ Accessor<StructLike> accessor = inputSchema.accessorForField(field.sourceId());
+ Preconditions.checkArgument(accessor != null,
+ "Cannot build accessor for field: " + schema.findField(field.sourceId()));
+ this.accessors[i] = accessor;
+ this.transforms[i] = field.transform();
+ }
+ }
+
+ private PartitionKey(PartitionKey toCopy) {
+ this.spec = toCopy.spec;
+ this.size = toCopy.size;
+ this.partitionTuple = new Object[toCopy.partitionTuple.length];
+ this.transforms = toCopy.transforms;
+ this.accessors = toCopy.accessors;
+
+ System.arraycopy(toCopy.partitionTuple, 0, this.partitionTuple, 0, partitionTuple.length);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("[");
+ for (int i = 0; i < partitionTuple.length; i += 1) {
+ if (i > 0) {
+ sb.append(", ");
+ }
+ sb.append(partitionTuple[i]);
+ }
+ sb.append("]");
+ return sb.toString();
+ }
+
+ public PartitionKey copy() {
+ return new PartitionKey(this);
+ }
+
+ public String toPath() {
+ return spec.partitionToPath(this);
+ }
+
+ /**
+ * Replace this key's partition values with the partition values for the row.
+ *
+ * @param row a StructLike row
+ */
+ @SuppressWarnings("unchecked")
+ public void partition(StructLike row) {
+ for (int i = 0; i < partitionTuple.length; i += 1) {
+ Transform<Object, Object> transform = transforms[i];
+ partitionTuple[i] = transform.apply(accessors[i].get(row));
+ }
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public <T> T get(int pos, Class<T> javaClass) {
+ return javaClass.cast(partitionTuple[pos]);
+ }
+
+ @Override
+ public <T> void set(int pos, T value) {
+ partitionTuple[pos] = value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ } else if (!(o instanceof PartitionKey)) {
+ return false;
+ }
+
+ PartitionKey that = (PartitionKey) o;
+ return Arrays.equals(partitionTuple, that.partitionTuple);
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(partitionTuple);
+ }
+}
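
Reviewer note: because the relocated PartitionKey accepts any StructLike, it can now be exercised without Spark. The sketch below is illustrative only and not part of this patch; the schema, the field names, and the use of GenericRecord from the iceberg-data module are assumptions made for the example.

    import org.apache.iceberg.PartitionKey;
    import org.apache.iceberg.PartitionSpec;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.data.GenericRecord;
    import org.apache.iceberg.types.Types;

    public class PartitionKeyExample {
      public static void main(String[] args) {
        // Hypothetical two-column schema; "category" drives the partitioning.
        Schema schema = new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.required(2, "category", Types.StringType.get()));

        PartitionSpec spec = PartitionSpec.builderFor(schema)
            .identity("category")
            .build();

        // One reusable key; partition(...) overwrites its values on every call.
        PartitionKey key = new PartitionKey(spec, schema);

        // GenericRecord implements StructLike, so any row source works, not just Spark rows.
        GenericRecord record = GenericRecord.create(schema);
        record.setField("id", 1L);
        record.setField("category", "books");

        key.partition(record);
        System.out.println(key.toPath()); // prints: category=books
      }
    }
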
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java b/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java
index f3a10301de36..8c41e77d0f10 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/BaseWriter.java
@@ -26,6 +26,7 @@
import org.apache.iceberg.DataFiles;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.Metrics;
+import org.apache.iceberg.PartitionKey;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.encryption.EncryptedOutputFile;
import org.apache.iceberg.io.FileAppender;
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java b/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java
new file mode 100644
index 000000000000..ef1eb08d873c
--- /dev/null
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.source;
+
+import java.nio.ByteBuffer;
+import java.util.function.BiFunction;
+import java.util.stream.Stream;
+import org.apache.iceberg.StructLike;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Class to adapt a Spark {@code InternalRow} to Iceberg {@link StructLike} for uses like
+ * {@link org.apache.iceberg.PartitionKey#partition(StructLike)}
+ */
+class InternalRowWrapper implements StructLike {
+ private final DataType[] types;
+ private final BiFunction<InternalRow, Integer, ?>[] getters;
+ private InternalRow row = null;
+
+ @SuppressWarnings("unchecked")
+ InternalRowWrapper(StructType rowType) {
+ this.types = Stream.of(rowType.fields())
+ .map(StructField::dataType)
+ .toArray(DataType[]::new);
+ this.getters = Stream.of(types)
+ .map(InternalRowWrapper::getter)
+ .toArray(BiFunction[]::new);
+ }
+
+ InternalRowWrapper wrap(InternalRow internalRow) {
+ this.row = internalRow;
+ return this;
+ }
+
+ @Override
+ public int size() {
+ return types.length;
+ }
+
+ @Override
+ public <T> T get(int pos, Class<T> javaClass) {
+ if (row.isNullAt(pos)) {
+ return null;
+ } else if (getters[pos] != null) {
+ return javaClass.cast(getters[pos].apply(row, pos));
+ }
+
+ return javaClass.cast(row.get(pos, types[pos]));
+ }
+
+ @Override
+ public <T> void set(int pos, T value) {
+ row.update(pos, value);
+ }
+
+ private static BiFunction<InternalRow, Integer, ?> getter(DataType type) {
+ if (type instanceof StringType) {
+ return (row, pos) -> row.getUTF8String(pos).toString();
+ } else if (type instanceof DecimalType) {
+ DecimalType decimal = (DecimalType) type;
+ return (row, pos) ->
+ row.getDecimal(pos, decimal.precision(), decimal.scale()).toJavaBigDecimal();
+ } else if (type instanceof BinaryType) {
+ return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos));
+ } else if (type instanceof StructType) {
+ StructType structType = (StructType) type;
+ InternalRowWrapper nestedWrapper = new InternalRowWrapper(structType);
+ return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size()));
+ }
+
+ return null;
+ }
+}
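
Reviewer note: the wrapper is deliberately mutable; wrap(InternalRow) just swaps the backing row so a single instance can serve an entire write stream with no per-row allocation. A minimal sketch of its behavior follows, assuming a caller in the same package (the class is package-private); the row type and values are made up for illustration.

    package org.apache.iceberg.spark.source; // same package: InternalRowWrapper is package-private

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    import org.apache.spark.unsafe.types.UTF8String;

    public class InternalRowWrapperExample {
      public static void main(String[] args) {
        // Hypothetical Spark row type: (id bigint, category string).
        StructType sparkType = new StructType()
            .add("id", DataTypes.LongType)
            .add("category", DataTypes.StringType);

        InternalRowWrapper wrapper = new InternalRowWrapper(sparkType);

        // Spark stores strings as UTF8String; the StringType getter converts to java.lang.String.
        InternalRow row = new GenericInternalRow(
            new Object[] {1L, UTF8String.fromString("books")});

        wrapper.wrap(row);
        Long id = wrapper.get(0, Long.class);           // no special getter: falls back to row.get(pos, type)
        String category = wrapper.get(1, String.class); // converted by the StringType getter
        System.out.println(id + ", " + category);       // 1, books
      }
    }
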
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java b/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java
index 1daf889e0512..08e66df79362 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/OutputFileFactory.java
@@ -22,6 +22,7 @@
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.PartitionKey;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.encryption.EncryptedOutputFile;
import org.apache.iceberg.encryption.EncryptionManager;
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionKey.java b/spark/src/main/java/org/apache/iceberg/spark/source/PartitionKey.java
deleted file mode 100644
index 292be8a9e3b5..000000000000
--- a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionKey.java
+++ /dev/null
@@ -1,364 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iceberg.spark.source;
-
-import java.lang.reflect.Array;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import org.apache.iceberg.PartitionField;
-import org.apache.iceberg.PartitionSpec;
-import org.apache.iceberg.Schema;
-import org.apache.iceberg.StructLike;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
-import org.apache.iceberg.spark.SparkSchemaUtil;
-import org.apache.iceberg.transforms.Transform;
-import org.apache.iceberg.types.Type;
-import org.apache.iceberg.types.TypeUtil;
-import org.apache.iceberg.types.Types;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.types.UTF8String;
-
-class PartitionKey implements StructLike {
-
- private final PartitionSpec spec;
- private final int size;
- private final Object[] partitionTuple;
- private final Transform[] transforms;
- private final Accessor<InternalRow>[] accessors;
-
- @SuppressWarnings("unchecked")
- PartitionKey(PartitionSpec spec, Schema inputSchema) {
- this.spec = spec;
-
- List<PartitionField> fields = spec.fields();
- this.size = fields.size();
- this.partitionTuple = new Object[size];
- this.transforms = new Transform[size];
- this.accessors = (Accessor<InternalRow>[]) Array.newInstance(Accessor.class, size);
-
- Schema schema = spec.schema();
- Map<Integer, Accessor<InternalRow>> newAccessors = buildAccessors(inputSchema);
- for (int i = 0; i < size; i += 1) {
- PartitionField field = fields.get(i);
- Accessor<InternalRow> accessor = newAccessors.get(field.sourceId());
- if (accessor == null) {
- throw new RuntimeException(
- "Cannot build accessor for field: " + schema.findField(field.sourceId()));
- }
- this.accessors[i] = accessor;
- this.transforms[i] = field.transform();
- }
- }
-
- private PartitionKey(PartitionKey toCopy) {
- this.spec = toCopy.spec;
- this.size = toCopy.size;
- this.partitionTuple = new Object[toCopy.partitionTuple.length];
- this.transforms = toCopy.transforms;
- this.accessors = toCopy.accessors;
-
- for (int i = 0; i < partitionTuple.length; i += 1) {
- this.partitionTuple[i] = defensiveCopyIfNeeded(toCopy.partitionTuple[i]);
- }
- }
-
- private Object defensiveCopyIfNeeded(Object obj) {
- if (obj instanceof UTF8String) {
- // bytes backing the UTF8 string might be reused
- byte[] bytes = ((UTF8String) obj).getBytes();
- return UTF8String.fromBytes(Arrays.copyOf(bytes, bytes.length));
- }
- return obj;
- }
-
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder();
- sb.append("[");
- for (int i = 0; i < partitionTuple.length; i += 1) {
- if (i > 0) {
- sb.append(", ");
- }
- sb.append(partitionTuple[i]);
- }
- sb.append("]");
- return sb.toString();
- }
-
- PartitionKey copy() {
- return new PartitionKey(this);
- }
-
- String toPath() {
- return spec.partitionToPath(this);
- }
-
- @SuppressWarnings("unchecked")
- void partition(InternalRow row) {
- for (int i = 0; i < partitionTuple.length; i += 1) {
- Transform<Object, Object> transform = transforms[i];
- partitionTuple[i] = transform.apply(accessors[i].get(row));
- }
- }
-
- @Override
- public int size() {
- return size;
- }
-
- @Override
- @SuppressWarnings("unchecked")
- public <T> T get(int pos, Class<T> javaClass) {
- return javaClass.cast(partitionTuple[pos]);
- }
-
- @Override
- public <T> void set(int pos, T value) {
- partitionTuple[pos] = value;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- } else if (!(o instanceof PartitionKey)) {
- return false;
- }
-
- PartitionKey that = (PartitionKey) o;
- return Arrays.equals(partitionTuple, that.partitionTuple);
- }
-
- @Override
- public int hashCode() {
- return Arrays.hashCode(partitionTuple);
- }
-
- private interface Accessor<T> {
- Object get(T container);
- }
-
- private static Map<Integer, Accessor<InternalRow>> buildAccessors(Schema schema) {
- return TypeUtil.visit(schema, new BuildPositionAccessors());
- }
-
- private static Accessor<InternalRow> newAccessor(int position, Type type) {
- switch (type.typeId()) {
- case STRING:
- return new StringAccessor(position, SparkSchemaUtil.convert(type));
- case DECIMAL:
- return new DecimalAccessor(position, SparkSchemaUtil.convert(type));
- case BINARY:
- return new BytesAccessor(position, SparkSchemaUtil.convert(type));
- default:
- return new PositionAccessor(position, SparkSchemaUtil.convert(type));
- }
- }
-
- private static Accessor<InternalRow> newAccessor(int position, boolean isOptional, Types.StructType type,
- Accessor<InternalRow> accessor) {
- int size = type.fields().size();
- if (isOptional) {
- // the wrapped position handles null layers
- return new WrappedPositionAccessor(position, size, accessor);
- } else if (accessor.getClass() == PositionAccessor.class) {
- return new Position2Accessor(position, size, (PositionAccessor) accessor);
- } else if (accessor instanceof Position2Accessor) {
- return new Position3Accessor(position, size, (Position2Accessor) accessor);
- } else {
- return new WrappedPositionAccessor(position, size, accessor);
- }
- }
-
- private static class BuildPositionAccessors
- extends TypeUtil.SchemaVisitor<Map<Integer, Accessor<InternalRow>>> {
- @Override
- public Map<Integer, Accessor<InternalRow>> schema(
- Schema schema, Map<Integer, Accessor<InternalRow>> structResult) {
- return structResult;
- }
-
- @Override
- public Map<Integer, Accessor<InternalRow>> struct(
- Types.StructType struct, List<Map<Integer, Accessor<InternalRow>>> fieldResults) {
- Map<Integer, Accessor<InternalRow>> accessors = Maps.newHashMap();
- List<Types.NestedField> fields = struct.fields();
- for (int i = 0; i < fieldResults.size(); i += 1) {
- Types.NestedField field = fields.get(i);
- Map<Integer, Accessor<InternalRow>> result = fieldResults.get(i);
- if (result != null) {
- for (Map.Entry<Integer, Accessor<InternalRow>> entry : result.entrySet()) {
- accessors.put(entry.getKey(), newAccessor(i, field.isOptional(),
- field.type().asNestedType().asStructType(), entry.getValue()));
- }
- } else {
- accessors.put(field.fieldId(), newAccessor(i, field.type()));
- }
- }
-
- if (accessors.isEmpty()) {
- return null;
- }
-
- return accessors;
- }
-
- @Override
- public Map<Integer, Accessor<InternalRow>> field(
- Types.NestedField field, Map<Integer, Accessor<InternalRow>> fieldResult) {
- return fieldResult;
- }
- }
-
- private static class PositionAccessor implements Accessor<InternalRow> {
- private final DataType type;
- private int position;
-
- private PositionAccessor(int position, DataType type) {
- this.position = position;
- this.type = type;
- }
-
- @Override
- public Object get(InternalRow row) {
- if (row.isNullAt(position)) {
- return null;
- }
- return row.get(position, type);
- }
-
- DataType type() {
- return type;
- }
-
- int position() {
- return position;
- }
- }
-
- private static class StringAccessor extends PositionAccessor {
- private StringAccessor(int position, DataType type) {
- super(position, type);
- }
-
- @Override
- public Object get(InternalRow row) {
- if (row.isNullAt(position())) {
- return null;
- }
- return row.get(position(), type()).toString();
- }
- }
-
- private static class DecimalAccessor extends PositionAccessor {
- private DecimalAccessor(int position, DataType type) {
- super(position, type);
- }
-
- @Override
- public Object get(InternalRow row) {
- if (row.isNullAt(position())) {
- return null;
- }
- return ((Decimal) row.get(position(), type())).toJavaBigDecimal();
- }
- }
-
- private static class BytesAccessor extends PositionAccessor {
- private BytesAccessor(int position, DataType type) {
- super(position, type);
- }
-
- @Override
- public Object get(InternalRow row) {
- if (row.isNullAt(position())) {
- return null;
- }
- return ByteBuffer.wrap((byte[]) row.get(position(), type()));
- }
- }
-
- private static class Position2Accessor implements Accessor<InternalRow> {
- private final int p0;
- private final int size0;
- private final int p1;
- private final DataType type;
-
- private Position2Accessor(int position, int size, PositionAccessor wrapped) {
- this.p0 = position;
- this.size0 = size;
- this.p1 = wrapped.position;
- this.type = wrapped.type;
- }
-
- @Override
- public Object get(InternalRow row) {
- return row.getStruct(p0, size0).get(p1, type);
- }
- }
-
- private static class Position3Accessor implements Accessor<InternalRow> {
- private final int p0;
- private final int size0;
- private final int p1;
- private final int size1;
- private final int p2;
- private final DataType type;
-
- private Position3Accessor(int position, int size, Position2Accessor wrapped) {
- this.p0 = position;
- this.size0 = size;
- this.p1 = wrapped.p0;
- this.size1 = wrapped.size0;
- this.p2 = wrapped.p1;
- this.type = wrapped.type;
- }
-
- @Override
- public Object get(InternalRow row) {
- return row.getStruct(p0, size0).getStruct(p1, size1).get(p2, type);
- }
- }
-
- private static class WrappedPositionAccessor implements Accessor<InternalRow> {
- private final int position;
- private final int size;
- private final Accessor<InternalRow> accessor;
-
- private WrappedPositionAccessor(int position, int size, Accessor<InternalRow> accessor) {
- this.position = position;
- this.size = size;
- this.accessor = accessor;
- }
-
- @Override
- public Object get(InternalRow row) {
- InternalRow inner = row.getStruct(position, size);
- if (inner != null) {
- return accessor.get(inner);
- }
- return null;
- }
- }
-}
diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java b/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java
index 7cc0c9f90555..0ead766f38c7 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/PartitionedWriter.java
@@ -22,11 +22,13 @@
import java.io.IOException;
import java.util.Set;
import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.PartitionKey;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -35,17 +37,19 @@ class PartitionedWriter extends BaseWriter {
private static final Logger LOG = LoggerFactory.getLogger(PartitionedWriter.class);

private final PartitionKey key;
+ private final InternalRowWrapper wrapper;
private final Set<PartitionKey> completedPartitions = Sets.newHashSet();

PartitionedWriter(PartitionSpec spec, FileFormat format, SparkAppenderFactory appenderFactory,
OutputFileFactory fileFactory, FileIO io, long targetFileSize, Schema writeSchema) {
super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
this.key = new PartitionKey(spec, writeSchema);
+ this.wrapper = new InternalRowWrapper(SparkSchemaUtil.convert(writeSchema));
}

@Override
public void write(InternalRow row) throws IOException {
- key.partition(row);
+ key.partition(wrapper.wrap(row));
PartitionKey currentKey = getCurrentKey();
if (!key.equals(currentKey)) {