Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.lenientFormat;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.hive.HiveStorageFormat.AVRO;
import static io.trino.plugin.hive.HiveStorageFormat.ORC;
import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION;
import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX;
import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME;
import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE;
import static io.trino.plugin.hive.util.HiveTypeTranslator.fromPrimitiveType;
import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo;
import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature;
Expand Down Expand Up @@ -219,13 +223,32 @@ public Optional<HiveType> getHiveTypeForDereferences(List<Integer> dereferences)
{
TypeInfo typeInfo = getTypeInfo();
for (int fieldIndex : dereferences) {
checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo);
StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
try {
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
if (typeInfo instanceof StructTypeInfo structTypeInfo) {
try {
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
}
catch (RuntimeException e) {
// return empty when failed to dereference, this could happen when partition and table schema mismatch
return Optional.empty();
}
}
catch (RuntimeException e) {
return Optional.empty();
else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) {
try {
if (fieldIndex == 0) {
// union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature}
return Optional.of(HiveType.toHiveType(UNION_FIELD_TAG_TYPE));
}
else {
typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1);
}
}
catch (RuntimeException e) {
// return empty when failed to dereference, this could happen when partition and table schema mismatch
return Optional.empty();
}
}
else {
throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo));
}
}
return Optional.of(toHiveType(typeInfo));
Expand All @@ -235,16 +258,35 @@ public List<String> getHiveDereferenceNames(List<Integer> dereferences)
{
ImmutableList.Builder<String> dereferenceNames = ImmutableList.builder();
TypeInfo typeInfo = getTypeInfo();
for (int fieldIndex : dereferences) {
checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo);
StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;

for (int i = 0; i < dereferences.size(); i++) {
int fieldIndex = dereferences.get(i);
checkArgument(fieldIndex >= 0, "fieldIndex cannot be negative");
checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(),
"fieldIndex should be less than the number of fields in the struct");
String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
dereferenceNames.add(fieldName);
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);

if (typeInfo instanceof StructTypeInfo structTypeInfo) {
checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(),
"fieldIndex should be less than the number of fields in the struct");

String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
dereferenceNames.add(fieldName);
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
}
else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) {
checkArgument((fieldIndex - 1) < unionTypeInfo.getAllUnionObjectTypeInfos().size(),
"fieldIndex should be less than the number of fields in the union plus tag field");

if (fieldIndex == 0) {
checkArgument(i == (dereferences.size() - 1), "Union's tag field should not have more subfields");
dereferenceNames.add(UNION_FIELD_TAG_NAME);
break;
}
else {
typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1);
dereferenceNames.add(UNION_FIELD_FIELD_PREFIX + (fieldIndex - 1));
}
}
else {
throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo));
}
}

return dereferenceNames.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public final class HiveTypeTranslator
{
private HiveTypeTranslator() {}

public static final String UNION_FIELD_TAG_NAME = "tag";
public static final String UNION_FIELD_FIELD_PREFIX = "field";
public static final Type UNION_FIELD_TAG_TYPE = TINYINT;

public static TypeInfo toTypeInfo(Type type)
{
requireNonNull(type, "type is null");
Expand Down Expand Up @@ -213,10 +217,10 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec
UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
List<TypeInfo> unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos();
ImmutableList.Builder<TypeSignatureParameter> typeSignatures = ImmutableList.builder();
typeSignatures.add(namedField("tag", TINYINT.getTypeSignature()));
typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature()));
for (int i = 0; i < unionObjectTypes.size(); i++) {
TypeInfo unionObjectType = unionObjectTypes.get(i);
typeSignatures.add(namedField("field" + i, toTypeSignature(unionObjectType, timestampPrecision)));
typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision)));
}
return rowType(typeSignatures.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Arrays;
import java.util.List;

import static io.trino.testing.TestingNames.randomNameSuffix;
import static io.trino.tests.product.TestGroups.SMOKE;
import static io.trino.tests.product.utils.QueryExecutors.onHive;
import static io.trino.tests.product.utils.QueryExecutors.onTrino;
Expand Down Expand Up @@ -51,6 +52,87 @@ public static Object[][] storageFormats()
return new String[][] {{"ORC"}, {"AVRO"}};
}

@DataProvider(name = "union_dereference_test_cases")
public static Object[][] unionDereferenceTestCases()
{
String tableUnionDereference = "test_union_dereference" + randomNameSuffix();
Comment thread
groupcache4321 marked this conversation as resolved.
Outdated
// Hive insertion for union type in AVRO format has bugs, so we test on different table schemas for AVRO than ORC.
return new Object[][] {{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<" +
"INT, STRING>)" +
"STORED AS %s",
tableUnionDereference,
"AVRO"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(0, 321, 'row1') " +
"UNION ALL " +
"SELECT create_union(1, 55, 'row2') ",
tableUnionDereference),
format("SELECT unionLevel0.field0 FROM %s WHERE unionLevel0.field0 IS NOT NULL", tableUnionDereference),
Arrays.asList(321),
format("SELECT unionLevel0.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference},
// there is an internal issue in Hive 1.2:
// unionLevel1 is declared as unionType<String, Int>, but has to be inserted by create_union(tagId, Int, String)
{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<INT, STRING," +
"STRUCT<intLevel1:INT, stringLevel1:STRING, unionLevel1:UNIONTYPE<STRING, INT>>>, intLevel0 INT )" +
"STORED AS %s",
tableUnionDereference,
"AVRO"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 5, 'testString'))), 8 " +
"UNION ALL " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 5, 'testString'))), 8 ",
tableUnionDereference),
format("SELECT unionLevel0.field2.unionLevel1.field1 FROM %s WHERE unionLevel0.field2.unionLevel1.field1 IS NOT NULL", tableUnionDereference),
Arrays.asList(5),
format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference},
{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<" +
"STRUCT<unionLevel1:UNIONTYPE<STRING, INT>>>)" +
"STORED AS %s",
tableUnionDereference,
"ORC"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(0, named_struct('unionLevel1', create_union(0, 'testString1', 23))) " +
"UNION ALL " +
"SELECT create_union(0, named_struct('unionLevel1', create_union(1, 'testString2', 45))) ",
tableUnionDereference),
format("SELECT unionLevel0.field0.unionLevel1.field0 FROM %s WHERE unionLevel0.field0.unionLevel1.field0 IS NOT NULL", tableUnionDereference),
Arrays.asList("testString1"),
format("SELECT unionLevel0.field0.unionLevel1.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference},
{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<INT, STRING," +
"STRUCT<intLevel1:INT, stringLevel1:STRING, unionLevel1:UNIONTYPE<STRING, INT>>>, intLevel0 INT )" +
"STORED AS %s",
tableUnionDereference,
"ORC"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 'testString', 5))), 8 " +
"UNION ALL " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 'testString', 5))), 8 ",
tableUnionDereference),
format("SELECT unionLevel0.field2.unionLevel1.field0 FROM %s WHERE unionLevel0.field2.unionLevel1.field0 IS NOT NULL", tableUnionDereference),
Arrays.asList("testString"),
format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference}};
}

@Test(dataProvider = "storage_formats", groups = SMOKE)
public void testReadUniontype(String storageFormat)
{
Expand Down Expand Up @@ -137,6 +219,25 @@ public void testReadUniontype(String storageFormat)
}
}

@Test(dataProvider = "union_dereference_test_cases", groups = SMOKE)
public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List<Object> expectedResult, String selectTagSql, List<Object> expectedTagResult, String dropTableSql)
{
// According to testing results, the Hive INSERT queries here only work in Hive 1.2
if (getHiveVersionMajor() != 1 || getHiveVersionMinor() != 2) {
Comment thread
groupcache4321 marked this conversation as resolved.
Outdated
throw new SkipException("This test can only be run with Hive 1.2 (default config)");
}

onHive().executeQuery(createTableSql);
onHive().executeQuery(insertSql);

QueryResult result = onTrino().executeQuery(selectSql);
assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedResult);
result = onTrino().executeQuery(selectTagSql);
assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedTagResult);

onTrino().executeQuery(dropTableSql);
}

@Test(dataProvider = "storage_formats", groups = SMOKE)
public void testUnionTypeSchemaEvolution(String storageFormat)
{
Expand Down