diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java index 7bb148593706..8c8b492ec843 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java @@ -22,6 +22,7 @@ import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.Optional; @@ -36,8 +37,8 @@ public class TestParquetPageSourceFactory { - @Test - public void testGetNestedMixedRepetitionColumnType() + @Test(dataProvider = "useColumnNames") + public void testGetNestedMixedRepetitionColumnType(boolean useColumnNames) { RowType rowType = rowType( RowType.field( @@ -64,7 +65,16 @@ public void testGetNestedMixedRepetitionColumnType() new GroupType(OPTIONAL, "optional_level2", new PrimitiveType(REQUIRED, INT32, "required_level3")))); assertEquals( - ParquetPageSourceFactory.getColumnType(columnHandle, fileSchema, true).get(), + ParquetPageSourceFactory.getColumnType(columnHandle, fileSchema, useColumnNames).get(), fileSchema.getType("optional_level1")); } + + @DataProvider + public Object[][] useColumnNames() + { + return new Object[][] { + {true}, // use column name + {false} // use column index + }; + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java index 32c0fcf2b28a..ba10daa102c1 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -28,6 +28,7 @@ import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; @@ -52,8 +53,8 @@ public class TestParquetPredicateUtils { - @Test - public void testParquetTupleDomainPrimitiveArray() + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainPrimitiveArray(boolean useColumnNames) { HiveColumnHandle columnHandle = createBaseColumn("my_array", 0, HiveType.valueOf("array"), new ArrayType(INTEGER), REGULAR, Optional.empty()); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(INTEGER)))); @@ -63,12 +64,12 @@ public void testParquetTupleDomainPrimitiveArray() new GroupType(REPEATED, "bag", new PrimitiveType(OPTIONAL, INT32, "array_element")))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); - TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, true); + TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } - @Test - public void testParquetTupleDomainStructArray() + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructArray(boolean useColumnNames) { RowType.Field rowField = new RowType.Field(Optional.of("a"), INTEGER); RowType rowType = RowType.from(ImmutableList.of(rowField)); @@ -83,12 +84,12 @@ public void testParquetTupleDomainStructArray() new GroupType(OPTIONAL, "array_element", new PrimitiveType(OPTIONAL, INT32, "a"))))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); - TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, true); + TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } - @Test - public void testParquetTupleDomainPrimitive() + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainPrimitive(boolean useColumnNames) { HiveColumnHandle columnHandle = createBaseColumn("my_primitive", 0, HiveType.valueOf("bigint"), BIGINT, REGULAR, Optional.empty()); Domain singleValueDomain = Domain.singleValue(BIGINT, 123L); @@ -97,7 +98,7 @@ public void testParquetTupleDomainPrimitive() MessageType fileSchema = new MessageType("hive_schema", new PrimitiveType(OPTIONAL, INT64, "my_primitive")); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); - TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, true); + TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertEquals(tupleDomain.getDomains().get().size(), 1); ColumnDescriptor descriptor = tupleDomain.getDomains().get().keySet().iterator().next(); @@ -108,8 +109,8 @@ public void testParquetTupleDomainPrimitive() assertEquals(predicateDomain, singleValueDomain); } - @Test - public void testParquetTupleDomainStruct() + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStruct(boolean useColumnNames) { RowType rowType = rowType( RowType.field("a", INTEGER), @@ -123,12 +124,12 @@ public void testParquetTupleDomainStruct() new PrimitiveType(OPTIONAL, INT32, "a"), new PrimitiveType(OPTIONAL, INT32, "b"))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); - TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, true); + TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } - @Test - public void testParquetTupleDomainMap() + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainMap(boolean useColumnNames) { MapType mapType = new MapType(INTEGER, INTEGER, new TypeOperators()); @@ -143,7 +144,16 @@ public void testParquetTupleDomainMap() new PrimitiveType(OPTIONAL, INT32, "value")))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); - TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, true); + TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } + + @DataProvider + public Object[][] useColumnNames() + { + return new Object[][] { + {true}, // use column name + {false} // use column index + }; + } }