Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,13 @@ public static ColumnIdentity primitiveColumnIdentity(int id, String name)
}

// Convenience overload: builds the identity using the field's own name.
public static ColumnIdentity createColumnIdentity(Types.NestedField column)
{
    return createColumnIdentity(column.name(), column);
}

public static ColumnIdentity createColumnIdentity(String name, Types.NestedField column)
{
int id = column.fieldId();
String name = column.name();
org.apache.iceberg.types.Type fieldType = column.type();

if (!fieldType.isNestedType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
import com.google.common.collect.Iterables;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import org.apache.iceberg.types.Types;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static io.trino.plugin.iceberg.ColumnIdentity.createColumnIdentity;
import static io.trino.plugin.iceberg.TypeConverter.toTrinoType;
import static java.util.Objects.requireNonNull;

public class IcebergColumnHandle
Expand Down Expand Up @@ -166,4 +170,24 @@ public String toString()
{
return getId() + ":" + getName() + ":" + type.getDisplayName();
}

/**
 * Creates a handle for a top-level Iceberg column.
 *
 * @param column the Iceberg schema field to wrap
 * @param typeManager used to map the Iceberg type to the Trino type
 */
public static IcebergColumnHandle create(Types.NestedField column, TypeManager typeManager)
{
    // Compute the Trino type once instead of converting twice
    Type trinoType = toTrinoType(column.type(), typeManager);
    return new IcebergColumnHandle(
            createColumnIdentity(column),
            trinoType,
            ImmutableList.of(),
            trinoType,
            Optional.ofNullable(column.doc()));
}

/**
 * Creates a handle for a (possibly nested) column addressed by an explicit name,
 * e.g. a dotted path produced by {@code TypeUtil.indexByName}.
 *
 * @param name the (possibly dotted) column name to record in the identity
 * @param column the Iceberg schema field to wrap
 * @param typeManager used to map the Iceberg type to the Trino type
 */
public static IcebergColumnHandle create(String name, Types.NestedField column, TypeManager typeManager)
{
    // Compute the Trino type once instead of converting twice
    Type trinoType = toTrinoType(column.type(), typeManager);
    return new IcebergColumnHandle(
            createColumnIdentity(name, column),
            trinoType,
            ImmutableList.of(),
            trinoType,
            Optional.ofNullable(column.doc()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
import static io.trino.plugin.iceberg.IcebergTableProperties.PARTITIONING_PROPERTY;
import static io.trino.plugin.iceberg.IcebergTableProperties.getPartitioning;
import static io.trino.plugin.iceberg.IcebergUtil.deserializePartitionValue;
import static io.trino.plugin.iceberg.IcebergUtil.getColumns;
import static io.trino.plugin.iceberg.IcebergUtil.getAllColumns;
import static io.trino.plugin.iceberg.IcebergUtil.getFileFormat;
import static io.trino.plugin.iceberg.IcebergUtil.getPartitionKeys;
import static io.trino.plugin.iceberg.IcebergUtil.getTableComment;
Expand Down Expand Up @@ -263,7 +263,7 @@ public ConnectorTableProperties getTableProperties(ConnectorSession session, Con
DiscretePredicates discretePredicates = null;
if (!partitionSourceIds.isEmpty()) {
// Extract identity partition columns
Map<Integer, IcebergColumnHandle> columns = getColumns(icebergTable.schema(), typeManager).stream()
Map<Integer, IcebergColumnHandle> columns = getAllColumns(icebergTable.schema(), typeManager).stream()
.filter(column -> partitionSourceIds.contains(column.getId()))
.collect(toImmutableMap(IcebergColumnHandle::getId, Function.identity()));

Expand Down Expand Up @@ -340,7 +340,7 @@ public Map<String, ColumnHandle> getColumnHandles(ConnectorSession session, Conn
{
IcebergTableHandle table = (IcebergTableHandle) tableHandle;
Table icebergTable = catalog.loadTable(session, table.getSchemaTableName());
return getColumns(icebergTable.schema(), typeManager).stream()
return getAllColumns(icebergTable.schema(), typeManager).stream()
.collect(toImmutableMap(IcebergColumnHandle::getName, identity()));
}

Expand Down Expand Up @@ -432,7 +432,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con
tableMetadata.getTable().getTableName(),
SchemaParser.toJson(transaction.table().schema()),
PartitionSpecParser.toJson(transaction.table().spec()),
getColumns(transaction.table().schema(), typeManager),
getAllColumns(transaction.table().schema(), typeManager),
transaction.table().location(),
getFileFormat(transaction.table()),
transaction.table().properties());
Expand All @@ -458,7 +458,7 @@ private Optional<ConnectorNewTableLayout> getWriteLayout(Schema tableSchema, Par
return Optional.empty();
}

Map<Integer, IcebergColumnHandle> columnById = getColumns(tableSchema, typeManager).stream()
Map<Integer, IcebergColumnHandle> columnById = getAllColumns(tableSchema, typeManager).stream()
.collect(toImmutableMap(IcebergColumnHandle::getId, identity()));

List<IcebergColumnHandle> partitioningColumns = partitionSpec.fields().stream()
Expand Down Expand Up @@ -492,7 +492,7 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto
table.getTableName(),
SchemaParser.toJson(icebergTable.schema()),
PartitionSpecParser.toJson(icebergTable.spec()),
getColumns(icebergTable.schema(), typeManager),
getAllColumns(icebergTable.schema(), typeManager),
icebergTable.location(),
getFileFormat(icebergTable),
icebergTable.properties());
Expand Down Expand Up @@ -914,7 +914,7 @@ public ConnectorInsertTableHandle beginRefreshMaterializedView(ConnectorSession
table.getTableName(),
SchemaParser.toJson(icebergTable.schema()),
PartitionSpecParser.toJson(icebergTable.spec()),
getColumns(icebergTable.schema(), typeManager),
getAllColumns(icebergTable.schema(), typeManager),
icebergTable.location(),
getFileFormat(icebergTable),
icebergTable.properties());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@
import org.apache.iceberg.Schema;
import org.apache.iceberg.io.LocationProvider;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.types.Types;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
Expand All @@ -66,6 +69,7 @@
import static io.trino.plugin.iceberg.PartitionTransforms.getColumnTransform;
import static io.trino.plugin.iceberg.util.Timestamps.getTimestampTz;
import static io.trino.plugin.iceberg.util.Timestamps.timestampTzToMicros;
import static io.trino.spi.block.ColumnarRow.toColumnarRow;
import static io.trino.spi.type.Decimals.readBigDecimal;
import static io.trino.spi.type.TimeType.TIME_MICROS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS;
Expand Down Expand Up @@ -129,7 +133,7 @@ public IcebergPageSink(
this.session = requireNonNull(session, "session is null");
this.fileFormat = requireNonNull(fileFormat, "fileFormat is null");
this.maxOpenWriters = maxOpenWriters;
this.pagePartitioner = new PagePartitioner(pageIndexerFactory, toPartitionColumns(inputColumns, partitionSpec));
this.pagePartitioner = new PagePartitioner(pageIndexerFactory, toPartitionColumns(inputColumns, partitionSpec, outputSchema));
}

@Override
Expand Down Expand Up @@ -326,14 +330,24 @@ private static Optional<PartitionData> getPartitionData(List<PartitionColumn> co
Object[] values = new Object[columns.size()];
for (int i = 0; i < columns.size(); i++) {
PartitionColumn column = columns.get(i);
Block block = page.getBlock(column.getSourceChannel());
Block block = getPartitionBlock(column, page);
Type type = column.getSourceType();
Object value = getIcebergValue(block, position, type);
values[i] = applyTransform(column.getField().transform(), value);
}
return Optional.of(new PartitionData(values));
}

// Walks the channel path for a partition column: the first entry is the
// top-level page channel, each subsequent entry selects a field within the
// row block (for nested partition sources).
private static Block getPartitionBlock(PartitionColumn column, Page page)
{
    List<Integer> channels = column.getSourceChannels();
    Block result = page.getBlock(channels.get(0));
    for (Integer fieldIndex : channels.subList(1, channels.size())) {
        result = toColumnarRow(result).getField(fieldIndex);
    }
    return result;
}

@SuppressWarnings("unchecked")
private static Object applyTransform(Transform<?, ?> transform, Object value)
{
Expand Down Expand Up @@ -384,24 +398,32 @@ public static Object getIcebergValue(Block block, int position, Type type)
throw new UnsupportedOperationException("Type not supported as partition column: " + type.getDisplayName());
}

private static List<PartitionColumn> toPartitionColumns(List<IcebergColumnHandle> handles, PartitionSpec partitionSpec)
private static List<PartitionColumn> toPartitionColumns(List<IcebergColumnHandle> handles, PartitionSpec partitionSpec, Schema schema)
{
Map<Integer, Integer> idChannels = new HashMap<>();
for (int i = 0; i < handles.size(); i++) {
idChannels.put(handles.get(i).getId(), i);
}

return partitionSpec.fields().stream()
.map(field -> {
Integer channel = idChannels.get(field.sourceId());
checkArgument(channel != null, "partition field not found: %s", field);
Type inputType = handles.get(channel).getType();
ColumnTransform transform = getColumnTransform(field, inputType);
return new PartitionColumn(field, channel, inputType, transform.getType(), transform.getBlockTransform());
})
.map(field -> getPartitionColumn(field, handles, schema.asStruct(), idChannels))
.collect(toImmutableList());
}

/**
 * Builds the {@link PartitionColumn} for a single partition field.
 *
 * @throws IllegalArgumentException if the source field cannot be resolved in the
 *         schema or has no corresponding input channel
 */
private static PartitionColumn getPartitionColumn(PartitionField field, List<IcebergColumnHandle> handles, Types.StructType schema, Map<Integer, Integer> idChannels)
{
    List<Integer> sourceIds;
    try {
        sourceIds = IcebergUtil.getIndexPathToField(schema, field.sourceId());
    }
    catch (Exception e) {
        // Preserve the root cause; the previous code ran checkArgument on an
        // always-null value inside the catch, discarding the original exception.
        throw new IllegalArgumentException("partition field not found: " + field, e);
    }
    // Guard against a missing channel; unboxing a null map lookup would NPE
    Integer channel = idChannels.get(field.sourceId());
    checkArgument(channel != null, "partition field not found: %s", field);
    Type inputType = handles.get(channel).getType();
    ColumnTransform transform = getColumnTransform(field, inputType);
    return new PartitionColumn(field, sourceIds, inputType, transform.getType(), transform.getBlockTransform());
}

private static class WriteContext
{
private final IcebergFileWriter writer;
Expand Down Expand Up @@ -449,14 +471,26 @@ public int[] partitionPage(Page page)
Block[] blocks = new Block[columns.size()];
for (int i = 0; i < columns.size(); i++) {
PartitionColumn column = columns.get(i);
Block block = page.getBlock(column.getSourceChannel());
Block block = getPartitionBlock(column, page);
blocks[i] = column.getBlockTransform().apply(block);
}
Page transformed = new Page(page.getPositionCount(), blocks);

return pageIndexer.indexPage(transformed);
}

/**
 * Resolves the block holding a partition source value, descending through
 * columnar row blocks for nested (dotted) partition sources.
 */
private Block getPartitionBlock(PartitionColumn column, Page page)
{
    List<Integer> path = column.getSourceChannels();
    Block current = page.getBlock(path.get(0));
    for (int i = 1; i < path.size(); i++) {
        current = toColumnarRow(current).getField(path.get(i));
    }
    return current;
}

public int getMaxIndex()
{
return pageIndexer.getMaxIndex();
Expand All @@ -471,15 +505,15 @@ public List<PartitionColumn> getColumns()
private static class PartitionColumn
{
private final PartitionField field;
private final int sourceChannel;
private final List<Integer> sourceChannels;
private final Type sourceType;
private final Type resultType;
private final Function<Block, Block> blockTransform;

// Constructs a partition column descriptor. sourceChannels is the path to the
// source value: top-level page channel first, then nested row-field indexes.
// (The stale pre-merge single-channel signature line is removed here; the
// closing brace of this constructor lies in the collapsed part of the diff.)
public PartitionColumn(PartitionField field, List<Integer> sourceChannels, Type sourceType, Type resultType, Function<Block, Block> blockTransform)
{
    this.field = requireNonNull(field, "field is null");
    // Null-check and defensively copy so the channel path cannot be mutated later
    this.sourceChannels = List.copyOf(requireNonNull(sourceChannels, "sourceChannels is null"));
    this.sourceType = requireNonNull(sourceType, "sourceType is null");
    this.resultType = requireNonNull(resultType, "resultType is null");
    this.blockTransform = requireNonNull(blockTransform, "blockTransform is null");
Expand All @@ -492,7 +526,12 @@ public PartitionField getField()

// Legacy accessor: returns the top-level page channel (first path element).
// The stale pre-merge "return sourceChannel;" line (a duplicate return left by
// the diff scrape, which would not compile) is removed.
public int getSourceChannel()
{
    return sourceChannels.get(0);
}

// Full channel path: top-level page channel followed by nested row-field indexes.
public List<Integer> getSourceChannels()
{
    return sourceChannels;
}

public Type getSourceType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.apache.iceberg.io.LocationProvider;
import org.apache.iceberg.types.Type.PrimitiveType;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.types.Types.NestedField;
import org.apache.iceberg.types.Types.StructType;

Expand All @@ -56,6 +57,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -150,6 +152,72 @@ public static Table getIcebergTableWithMetadata(
return new BaseTable(operations, quotedTableName(table));
}

// Returns a handle for every column in the schema, including nested struct
// members (indexByName yields dotted paths such as "parent.child").
// NOTE(review): the iteration order of indexByName's key set is not obviously
// deterministic — confirm callers do not rely on a stable column order.
public static List<IcebergColumnHandle> getAllColumns(Schema schema, TypeManager typeManager)
{
    return TypeUtil
            .indexByName(schema.asStruct())
            .keySet()
            .stream()
            .map(name -> IcebergColumnHandle.create(name, schema.findField(name), typeManager))
            .collect(toImmutableList());
}

/**
 * Returns the component names (outermost first) on the path from the schema
 * root to the field with the given id, e.g. ["parent", "child"].
 *
 * @param schema the root struct to search
 * @param sourceId the Iceberg field id of the target field
 */
public static List<String> getNestedColumnNames(Types.StructType schema, Integer sourceId)
{
    Map<Integer, Integer> parentIndex = TypeUtil.indexParents(schema);
    Map<Integer, Types.NestedField> idIndex = TypeUtil.indexById(schema);
    // Fixed raw type: was "new LinkedList()" without the type argument
    LinkedList<String> parentColumns = new LinkedList<>();

    // Build the path back-to-front by walking parent links up to the root
    parentColumns.addFirst(idIndex.get(sourceId).name());
    Integer current = parentIndex.get(sourceId);

    while (current != null) {
        parentColumns.addFirst(idIndex.get(current).name());
        current = parentIndex.get(current);
    }
    return parentColumns;
}

/**
 * Returns the ordinal position of the field with the given name within the struct.
 * Only throws the unchecked IllegalArgumentException, so the former
 * "throws Exception" declaration is dropped.
 *
 * @throws IllegalArgumentException if no field has that name
 */
private static Integer getFieldPosFromSchema(String name, Types.StructType schema)
{
    for (int i = 0; i < schema.fields().size(); i++) {
        // equals is the idiomatic String comparison (contentEquals is equivalent here)
        if (schema.fields().get(i).name().equals(name)) {
            return i;
        }
    }
    throw new IllegalArgumentException("Could not find field " + name + " in schema");
}

// Resolves the positional index path (one index per nesting level) for the
// field with the given id by first materializing its dotted-name path.
public static List<Integer> getIndexPathToField(Types.StructType schema, Integer sourceId) throws Exception
{
    return getIndexPathToField(schema, getNestedColumnNames(schema, sourceId));
}

/**
 * Resolves a nested field given by its component names (outermost first) to the
 * list of positional indexes, one per nesting level.
 *
 * @throws IllegalArgumentException if a component is missing, or a non-terminal
 *         component is not a struct
 */
public static List<Integer> getIndexPathToField(Types.StructType schema, List<String> fieldName) throws Exception
{
    // ArrayList instead of LinkedList: known size, sequential appends only
    List<Integer> sourceIds = new ArrayList<>(fieldName.size());
    Types.StructType current = schema;

    for (int i = 0; i < fieldName.size(); i++) {
        String name = fieldName.get(i);
        sourceIds.add(getFieldPosFromSchema(name, current));

        if (current.field(name).type().isStructType()) {
            // Descend into the struct for the next path component
            current = current.field(name).type().asStructType();
        }
        else if (i + 1 < fieldName.size()) {
            // Non-struct component with remaining path elements: path is unresolvable
            throw new IllegalArgumentException("Could not find field " + String.join(".", fieldName) + " in schema");
        }
        // Non-struct terminal component: loop ends naturally (was an explicit break)
    }

    return sourceIds;
}

public static long resolveSnapshotId(Table table, long snapshotId)
{
if (table.snapshot(snapshotId) != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

public final class PartitionFields
{
// An unquoted identifier: lowercase letter or underscore, then lowercase/digits/underscores.
private static final String IDENTIFIER = "[a-z_][a-z0-9_]*";
// A (possibly dot-nested) column reference, e.g. "a" or "a.b.c".
// The stale pre-merge single-identifier NAME declaration (a duplicate left by the
// diff scrape, which would not compile) is removed.
private static final String NAME = IDENTIFIER + "(\\." + IDENTIFIER + ")*";
private static final String FUNCTION_ARGUMENT_NAME = "\\((" + NAME + ")\\)";
private static final String FUNCTION_ARGUMENT_NAME_AND_INT = "\\((" + NAME + "), *(\\d+)\\)";

Expand Down Expand Up @@ -65,8 +66,8 @@ public static void parsePartitionField(PartitionSpec.Builder builder, String fie
tryMatch(field, MONTH_PATTERN, match -> builder.month(match.group(1))) ||
tryMatch(field, DAY_PATTERN, match -> builder.day(match.group(1))) ||
tryMatch(field, HOUR_PATTERN, match -> builder.hour(match.group(1))) ||
tryMatch(field, BUCKET_PATTERN, match -> builder.bucket(match.group(1), parseInt(match.group(2)))) ||
tryMatch(field, TRUNCATE_PATTERN, match -> builder.truncate(match.group(1), parseInt(match.group(2)))) ||
tryMatch(field, BUCKET_PATTERN, match -> builder.bucket(match.group(1), parseInt(match.group(match.groupCount())))) ||
tryMatch(field, TRUNCATE_PATTERN, match -> builder.truncate(match.group(1), parseInt(match.group(match.groupCount())))) ||
tryMatch(field, VOID_PATTERN, match -> builder.alwaysNull(match.group(1))) ||
false;
if (!matched) {
Expand Down
Loading