Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
Expand Down Expand Up @@ -126,6 +127,7 @@
import org.apache.iceberg.Transaction;
import org.apache.iceberg.UpdatePartitionSpec;
import org.apache.iceberg.UpdateProperties;
import org.apache.iceberg.UpdateSchema;
import org.apache.iceberg.UpdateStatistics;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.expressions.Expressions;
Expand Down Expand Up @@ -1574,16 +1576,72 @@ public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHa
verify(column.isBaseColumn(), "Cannot change nested field types");

Table icebergTable = catalog.loadTable(session, table.getSchemaTableName());
Type sourceType = icebergTable.schema().findType(column.getName());
Type newType = toIcebergType(type);
try {
icebergTable.updateSchema()
.updateColumn(column.getName(), toIcebergType(type).asPrimitiveType())
.commit();
UpdateSchema schemaUpdate = icebergTable.updateSchema();
buildUpdateSchema(column.getName(), sourceType, newType, schemaUpdate);
schemaUpdate.commit();
}
catch (RuntimeException e) {
throw new TrinoException(ICEBERG_COMMIT_ERROR, "Failed to set column type: " + firstNonNull(e.getMessage(), e), e);
}
}

/**
 * Recursively appends the schema changes needed to convert {@code sourceType} into
 * {@code newType} onto the pending {@code schemaUpdate}; the caller is responsible
 * for committing it.
 *
 * Supported conversions: primitive-to-primitive updates, and struct-to-struct changes
 * (adding, updating, dropping and reordering fields, applied recursively to nested
 * structs). Any other combination throws {@link IllegalArgumentException}, which the
 * caller surfaces as a "Cannot change type" failure.
 *
 * @param name dotted path of the column being changed (e.g. {@code "col.nested"})
 */
private static void buildUpdateSchema(String name, Type sourceType, Type newType, UpdateSchema schemaUpdate)
{
    if (sourceType.equals(newType)) {
        // Nothing to change for this path
        return;
    }
    if (sourceType.isPrimitiveType() && newType.isPrimitiveType()) {
        // Primitive widening; Iceberg validates on commit that the conversion is safe
        schemaUpdate.updateColumn(name, newType.asPrimitiveType());
        return;
    }
    if (sourceType instanceof StructType sourceRowType && newType instanceof StructType newRowType) {
        // Add, update or delete fields
        // NOTE(review): distinct() uses NestedField equality (which includes field ids), so a
        // field present in both structs will typically appear twice in this list; both visits
        // take the same name-keyed branch below, which looks idempotent — confirm
        List<NestedField> fields = Streams.concat(sourceRowType.fields().stream(), newRowType.fields().stream())
                .distinct()
                .collect(toImmutableList());
        for (NestedField field : fields) {
            if (fieldExists(sourceRowType, field.name()) && fieldExists(newRowType, field.name())) {
                // Field kept in both versions: recurse in case its type changed
                buildUpdateSchema(name + "." + field.name(), sourceRowType.fieldType(field.name()), newRowType.fieldType(field.name()), schemaUpdate);
            }
            else if (fieldExists(newRowType, field.name())) {
                // Field only in the new type: add it (existing rows read NULL for it)
                schemaUpdate.addColumn(name, field.name(), field.type());
            }
            else {
                // Field only in the source type: drop it
                schemaUpdate.deleteColumn(name + "." + field.name());
            }
        }

        // Order fields based on the new column type
        String currentName = null;
        for (NestedField field : newRowType.fields()) {
            String path = name + "." + field.name();
            if (currentName == null) {
                schemaUpdate.moveFirst(path);
            }
            else {
                schemaUpdate.moveAfter(path, currentName);
            }
            currentName = path;
        }

        return;
    }
    throw new IllegalArgumentException("Cannot change type from %s to %s".formatted(sourceType, newType));
}

/**
 * Returns whether {@code structType} declares a direct (non-nested) field named
 * {@code fieldName}.
 */
private static boolean fieldExists(StructType structType, String fieldName)
{
    return structType.fields().stream()
            .map(NestedField::name)
            .anyMatch(fieldName::equals);
}

private List<ColumnMetadata> getColumnMetadatas(Schema schema)
{
ImmutableList.Builder<ColumnMetadata> columns = ImmutableList.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6349,7 +6349,6 @@ protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColum
case "decimal(5,3) -> decimal(5,2)":
case "varchar -> char(20)":
case "array(integer) -> array(bigint)":
case "row(x integer) -> row(x bigint)":
// Iceberg allows updating column types if the update is safe. Safe updates are:
// - int to bigint
// - float to double
Expand All @@ -6367,7 +6366,7 @@ protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColum
@Override
protected void verifySetColumnTypeFailurePermissible(Throwable e)
{
    // Accept the error messages Iceberg/Trino emit for disallowed type changes, including
    // the "Cannot change type" message raised for unsupported struct conversions.
    // (The diff artifact that duplicated the old assertion line alongside this one is removed.)
    assertThat(e).hasMessageMatching(".*(Cannot change column type|not supported for Iceberg|Not a primitive type|Cannot change type ).*");
}

private Session prepareCleanUpSession()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.testing.sql.TestTable;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -73,4 +74,15 @@ protected Session withSmallRowGroups(Session session)
.setCatalogSessionProperty("iceberg", "parquet_writer_batch_size", "10")
.build();
}

@Override
protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColumnTypeSetup setup)
{
switch ("%s -> %s".formatted(setup.sourceColumnType(), setup.newColumnType())) {
case "row(x integer) -> row(y integer)":
// TODO https://github.com/trinodb/trino/issues/15822 The connector returns incorrect NULL when a field in row type doesn't exist in Parquet files
return Optional.of(setup.withNewValueLiteral("NULL"));
Comment on lines +83 to +84
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single-field row types aren't very common.
Let's make sure we have a test for renaming some, but not all fields in a row (eg have a row with two fields).

(We can keep this one too)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sent #15957

}
return super.filterSetColumnTypesDataProvider(setup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2591,6 +2591,50 @@ public static Object[][] testSetColumnTypeDataProvider()
});
}

/**
 * Verifies that struct (row) column type changes made through Trino's
 * {@code ALTER TABLE ... ALTER COLUMN ... SET DATA TYPE} are observed consistently
 * from both Trino and Spark: adding, updating, dropping, re-adding and reordering
 * nested fields. Runs once per storage format via the {@code storageFormats} provider.
 */
@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats")
public void testTrinoAlterStructColumnType(StorageFormat storageFormat)
{
    String baseTableName = "test_trino_alter_row_column_type_" + randomNameSuffix();
    String trinoTableName = trinoTableName(baseTableName);
    String sparkTableName = sparkTableName(baseTableName);

    // Seed a single row so each DDL step below can assert on the surviving data
    onTrino().executeQuery("CREATE TABLE " + trinoTableName + " " +
            "WITH (format = '" + storageFormat + "')" +
            "AS SELECT CAST(row(1, 2) AS row(a integer, b integer)) AS col");

    // Add a nested field; existing rows read NULL for the new field
    onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, b integer, c integer)");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b integer, c integer)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

    // Update a nested field (safe widening integer -> bigint)
    onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, b bigint, c integer)");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b bigint, c integer)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

    // Drop a nested field
    onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, c integer)");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.c FROM " + sparkTableName)).containsOnly(row(1, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.c FROM " + trinoTableName)).containsOnly(row(1, null));

    // Adding a nested field with the same name doesn't restore the old data
    onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, c integer, b bigint)");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer, b bigint)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.c, col.b FROM " + sparkTableName)).containsOnly(row(1, null, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.c, col.b FROM " + trinoTableName)).containsOnly(row(1, null, null));

    // Reorder fields; values stay attached to their field names, only positions change
    onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(c integer, b bigint, a integer)");
    assertEquals(getColumnType(baseTableName, "col"), "row(c integer, b bigint, a integer)");
    assertThat(onSpark().executeQuery("SELECT col.b, col.c, col.a FROM " + sparkTableName)).containsOnly(row(null, null, 1));
    assertThat(onTrino().executeQuery("SELECT col.b, col.c, col.a FROM " + trinoTableName)).containsOnly(row(null, null, 1));

    onTrino().executeQuery("DROP TABLE " + trinoTableName);
}

@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "testSparkAlterColumnType")
public void testSparkAlterColumnType(StorageFormat storageFormat, String sourceColumnType, String sourceValueLiteral, String newColumnType, Object newValue)
{
Expand Down Expand Up @@ -2637,6 +2681,50 @@ public static Object[][] testSparkAlterColumnType()
});
}

/**
 * Mirror of the Trino-driven struct-alteration test: the same sequence of nested-field
 * changes (add, update, drop, re-add, reorder) is applied through Spark's ALTER TABLE
 * syntax, and the results are verified from both engines. Runs once per storage format
 * via the {@code storageFormats} provider.
 */
@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats")
public void testSparkAlterStructColumnType(StorageFormat storageFormat)
{
    String baseTableName = "test_spark_alter_struct_column_type_" + randomNameSuffix();
    String trinoTableName = trinoTableName(baseTableName);
    String sparkTableName = sparkTableName(baseTableName);

    // Seed a single row so each DDL step below can assert on the surviving data
    onSpark().executeQuery("CREATE TABLE " + sparkTableName +
            " TBLPROPERTIES ('write.format.default' = '" + storageFormat + "')" +
            "AS SELECT named_struct('a', 1, 'b', 2) AS col");

    // Add a nested field; existing rows read NULL for the new field
    onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ADD COLUMN col.c integer");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b integer, c integer)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

    // Update a nested field (safe widening integer -> bigint)
    onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col.b TYPE bigint");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b bigint, c integer)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

    // Drop a nested field
    onSpark().executeQuery("ALTER TABLE " + sparkTableName + " DROP COLUMN col.b");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.c FROM " + sparkTableName)).containsOnly(row(1, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.c FROM " + trinoTableName)).containsOnly(row(1, null));

    // Adding a nested field with the same name doesn't restore the old data
    onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ADD COLUMN col.b bigint");
    assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer, b bigint)");
    assertThat(onSpark().executeQuery("SELECT col.a, col.c, col.b FROM " + sparkTableName)).containsOnly(row(1, null, null));
    assertThat(onTrino().executeQuery("SELECT col.a, col.c, col.b FROM " + trinoTableName)).containsOnly(row(1, null, null));

    // Reorder fields; values stay attached to their field names, only positions change
    onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col.a AFTER b");
    assertEquals(getColumnType(baseTableName, "col"), "row(c integer, b bigint, a integer)");
    assertThat(onSpark().executeQuery("SELECT col.b, col.c, col.a FROM " + sparkTableName)).containsOnly(row(null, null, 1));
    assertThat(onTrino().executeQuery("SELECT col.b, col.c, col.a FROM " + trinoTableName)).containsOnly(row(null, null, 1));

    onSpark().executeQuery("DROP TABLE " + sparkTableName);
}

private String getColumnType(String tableName, String columnName)
{
return (String) onTrino().executeQuery("SELECT data_type FROM " + TRINO_CATALOG + ".information_schema.columns " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2224,7 +2224,7 @@ public void testSetColumnTypes(SetColumnTypeSetup setup)
assertEquals(getColumnType(table.getName(), "col"), setup.newColumnType);
assertThat(query("SELECT * FROM " + table.getName()))
.skippingTypesCheck()
.matches("VALUES " + setup.newValueLiteral);
.matches("SELECT " + setup.newValueLiteral);
}
catch (Exception e) {
verifyUnsupportedTypeException(e, setup.sourceColumnType);
Expand Down Expand Up @@ -2273,6 +2273,13 @@ private List<SetColumnTypeSetup> setColumnTypeSetupData()
.add(new SetColumnTypeSetup("varchar", "'varchar-to-char'", "char(20)"))
.add(new SetColumnTypeSetup("array(integer)", "array[1]", "array(bigint)"))
.add(new SetColumnTypeSetup("row(x integer)", "row(1)", "row(x bigint)"))
.add(new SetColumnTypeSetup("row(x integer)", "row(1)", "row(y integer)", "cast(row(NULL) as row(x integer))")) // rename a field
.add(new SetColumnTypeSetup("row(x integer)", "row(1)", "row(x integer, y integer)", "cast(row(1, NULL) as row(x integer, y integer))")) // add a new field
.add(new SetColumnTypeSetup("row(x integer, y integer)", "row(1, 2)", "row(x integer)", "cast(row(1) as row(x integer))")) // remove an existing field
.add(new SetColumnTypeSetup("row(x integer, y integer)", "row(1, 2)", "row(y integer, x integer)", "cast(row(2, 1) as row(y integer, x integer))")) // reorder fields
.add(new SetColumnTypeSetup("row(x integer, y integer)", "row(1, 2)", "row(z integer, y integer, x integer)", "cast(row(null, 2, 1) as row(z integer, y integer, x integer))")) // reorder fields with a new field
.add(new SetColumnTypeSetup("row(x row(nested integer))", "row(row(1))", "row(x row(nested bigint))", "cast(row(row(1)) as row(x row(nested bigint)))")) // update a nested field
.add(new SetColumnTypeSetup("row(x row(a integer, b integer))", "row(row(1, 2))", "row(x row(b integer, a integer))", "cast(row(row(2, 1)) as row(x row(b integer, a integer)))")) // reorder a nested field
.build();
}

Expand Down