Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,14 @@ public static RelDataType convert(StructTypeInfo structType, final RelDataTypeFa
}

// Mimic the StructTypeInfo conversion to convert a UnionTypeInfo to the corresponding RelDataType
// The schema of output Struct conforms to https://github.com/trinodb/trino/pull/3483
public static RelDataType convert(UnionTypeInfo unionType, RelDataTypeFactory dtFactory) {
List<RelDataType> fTypes = unionType.getAllUnionObjectTypeInfos().stream()
.map(typeInfo -> convert(typeInfo, dtFactory)).collect(Collectors.toList());
List<String> fNames = IntStream.range(0, unionType.getAllUnionObjectTypeInfos().size()).mapToObj(i -> "tag_" + i)
List<String> fNames = IntStream.range(0, unionType.getAllUnionObjectTypeInfos().size()).mapToObj(i -> "field" + i)
.collect(Collectors.toList());
fTypes.add(0, dtFactory.createSqlType(SqlTypeName.TINYINT));
fNames.add(0, "tag");

RelDataType rowType = dtFactory.createStructType(fTypes, fNames);
return dtFactory.createTypeWithNullability(rowType, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import com.linkedin.coral.common.HiveSchema;
import com.linkedin.coral.common.HiveTable;

import static org.testng.Assert.*;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;


public class HiveTableTest {
Expand Down Expand Up @@ -65,6 +67,23 @@ public void testTable() throws Exception {
assertEquals(colC.getType().getComponentType().getSqlTypeName(), SqlTypeName.DOUBLE);
}

@Test
public void testTableWithUnion() throws Exception {
final RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl();

// test handling of union
Table unionTable = getTable("default", "union_table");
// union_table:(foo uniontype<int, double, array<string>, struct<a:int,b:string>>)
// expected outcome schema: struct<tag:tinyint, field0:int, field1:double, field2:array<string>, field3:struct<a:int,b:string>>
RelDataType rowType = unionTable.getRowType(typeFactory);
assertNotNull(rowType);

String expectedTypeString =
"RecordType(" + "RecordType(" + "TINYINT tag, INTEGER field0, DOUBLE field1, VARCHAR(65536) ARRAY field2, "
+ "RecordType(INTEGER a, VARCHAR(65536) b) field3" + ") " + "foo)";
assertEquals(rowType.toString(), expectedTypeString);
}

@Test
public void testGetDaliFunctionParams() throws HiveException, TException {
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ public void testSchemaPromotionView() {
assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
}

@Test
@Test(enabled = false)
public void testUnionExtractUDF() {
RelNode relNode = TestUtils.toRelNode("SELECT extract_union(foo) from union_table");
String targetSql = String.join("\n", "SELECT foo", "FROM default.union_table");
Expand Down