Skip to content

Commit cb71c5d

Browse files
groupcache4321yluan
authored and committed
Implement predicate push down for parquet nested columns
1 parent 9c9e951 commit cb71c5d

File tree

3 files changed

+125
-15
lines changed

3 files changed

+125
-15
lines changed

plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.apache.parquet.io.MessageColumnIO;
5858
import org.apache.parquet.schema.GroupType;
5959
import org.apache.parquet.schema.MessageType;
60+
import org.apache.parquet.schema.Type;
6061
import org.joda.time.DateTimeZone;
6162

6263
import javax.inject.Inject;
@@ -217,9 +218,6 @@ public static ReaderPageSource createPageSource(
217218
Optional<ParquetWriteValidation> parquetWriteValidation,
218219
int domainCompactionThreshold)
219220
{
220-
// Ignore predicates on partial columns for now.
221-
effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn());
222-
223221
MessageType fileSchema;
224222
MessageType requestedSchema;
225223
MessageColumnIO messageColumn;
@@ -433,19 +431,30 @@ public static TupleDomain<ColumnDescriptor> getParquetTupleDomain(
433431
continue;
434432
}
435433

436-
ColumnDescriptor descriptor;
437-
if (useColumnNames) {
438-
descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName()));
434+
ColumnDescriptor descriptor = null;
435+
436+
Optional<org.apache.parquet.schema.Type> baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames);
437+
// failed to look up the column from the file schema
438+
if (baseColumnType.isEmpty()) {
439+
continue;
440+
}
441+
else if (columnHandle.getHiveColumnProjectionInfo().isEmpty() && baseColumnType.get().isPrimitive()) {
442+
descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName()));
439443
}
440-
else {
441-
Optional<org.apache.parquet.schema.Type> parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false);
442-
if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) {
443-
// Parquet file has fewer column than partition
444-
// Or the field is a complex type
444+
else if (columnHandle.getHiveColumnProjectionInfo().isPresent() && !baseColumnType.get().isPrimitive()) {
445+
Optional<List<Type>> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get());
446+
// failed to look up subfields from the file schema
447+
if (subfieldTypes.isEmpty()) {
445448
continue;
446449
}
447-
descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName()));
450+
451+
ImmutableList.Builder<String> path = ImmutableList.builder();
452+
path.add(baseColumnType.get().getName());
453+
path.addAll(subfieldTypes.get().stream().map(Type::getName).toList());
454+
455+
descriptor = descriptorsByPath.get(path.build());
448456
}
457+
449458
if (descriptor != null) {
450459
predicate.put(descriptor, entry.getValue());
451460
}

plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5289,11 +5289,11 @@ public void testParquetOnlyNullsRowGroupPruning()
52895289
// Nested column `b` also has nulls count of 4096, but it contains non nulls as well
52905290
assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')");
52915291
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096);
5292-
// TODO replace with assertNoDataRead after nested column predicate pushdown
5292+
52935293
assertQueryStats(
52945294
getSession(),
52955295
"SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL",
5296-
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
5296+
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0),
52975297
results -> assertThat(results.getRowCount()).isEqualTo(0));
52985298
assertQueryStats(
52995299
getSession(),
@@ -5302,6 +5302,32 @@ public void testParquetOnlyNullsRowGroupPruning()
53025302
results -> assertThat(results.getRowCount()).isEqualTo(4096));
53035303
}
53045304

5305+
@Test
5306+
public void testParquetNestedRowGroupPruning()
5307+
{
5308+
String tableName = "test_primitive_column_nested_pruning_" + randomNameSuffix();
5309+
assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')");
5310+
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(300, 4096))", 4096);
5311+
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col != 300");
5312+
5313+
tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix();
5314+
// Nested column `a` has nulls count of 4096 and contains only nulls
5315+
// Nested column `b` also has nulls count of 4096, but it contains non nulls as well
5316+
assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b BIGINT)) WITH (format = 'PARQUET')");
5317+
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(300, 500))))", 4096);
5318+
5319+
assertQueryStats(
5320+
getSession(),
5321+
"SELECT * FROM " + tableName + " WHERE col.a != 300",
5322+
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0),
5323+
results -> assertThat(results.getRowCount()).isEqualTo(0));
5324+
assertQueryStats(
5325+
getSession(),
5326+
"SELECT * FROM " + tableName + " WHERE col.b = 500",
5327+
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
5328+
results -> assertThat(results.getRowCount()).isEqualTo(4096));
5329+
}
5330+
53055331
private void assertNoDataRead(@Language("SQL") String sql)
53065332
{
53075333
assertQueryStats(

plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import com.google.common.collect.ImmutableMap;
1818
import com.google.common.collect.Iterables;
1919
import io.trino.plugin.hive.HiveColumnHandle;
20+
import io.trino.plugin.hive.HiveColumnProjectionInfo;
2021
import io.trino.plugin.hive.HiveType;
2122
import io.trino.spi.predicate.Domain;
2223
import io.trino.spi.predicate.TupleDomain;
@@ -122,12 +123,86 @@ public void testParquetTupleDomainStruct(boolean useColumnNames)
122123
MessageType fileSchema = new MessageType("hive_schema",
123124
new GroupType(OPTIONAL, "my_struct",
124125
new PrimitiveType(OPTIONAL, INT32, "a"),
125-
new PrimitiveType(OPTIONAL, INT32, "b")));
126+
new PrimitiveType(OPTIONAL, INT32, "b"),
127+
new PrimitiveType(OPTIONAL, INT32, "c")));
126128
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
127129
TupleDomain<ColumnDescriptor> tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames);
128130
assertTrue(tupleDomain.isAll());
129131
}
130132

133+
@Test(dataProvider = "useColumnNames")
134+
public void testParquetTupleDomainStructNestedColumn(boolean useColumNames)
135+
{
136+
RowType baseType = rowType(
137+
RowType.field("a", INTEGER),
138+
RowType.field("b", INTEGER),
139+
RowType.field("c", INTEGER));
140+
141+
HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
142+
ImmutableList.of(1),
143+
ImmutableList.of("b"),
144+
HiveType.HIVE_INT,
145+
INTEGER);
146+
147+
HiveColumnHandle projectedColumn = new HiveColumnHandle(
148+
"row_field",
149+
0,
150+
HiveType.toHiveType(baseType),
151+
baseType,
152+
Optional.of(columnProjectionInfo),
153+
REGULAR,
154+
Optional.empty());
155+
156+
Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
157+
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));
158+
159+
MessageType fileSchema = new MessageType("hive_schema",
160+
new GroupType(OPTIONAL, "row_field",
161+
new PrimitiveType(OPTIONAL, INT32, "a"),
162+
new PrimitiveType(OPTIONAL, INT32, "b"),
163+
new PrimitiveType(OPTIONAL, INT32, "c")));
164+
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
165+
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames);
166+
assertEquals(calculatedTupleDomain.getDomains().get().size(), 1);
167+
ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b"));
168+
assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), (predicateDomain));
169+
}
170+
171+
@Test(dataProvider = "useColumnNames")
172+
public void testParquetTupleDomainStructNestedColumnNonExist(boolean useColumnNames)
173+
{
174+
RowType baseType = rowType(
175+
RowType.field("a", INTEGER),
176+
RowType.field("b", INTEGER),
177+
RowType.field("non_exist", INTEGER));
178+
179+
HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
180+
ImmutableList.of(2),
181+
ImmutableList.of("non_exist"),
182+
HiveType.HIVE_INT,
183+
INTEGER);
184+
185+
HiveColumnHandle projectedColumn = new HiveColumnHandle(
186+
"row_field",
187+
0,
188+
HiveType.toHiveType(baseType),
189+
baseType,
190+
Optional.of(columnProjectionInfo),
191+
REGULAR,
192+
Optional.empty());
193+
194+
Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
195+
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));
196+
197+
MessageType fileSchema = new MessageType("hive_schema",
198+
new GroupType(OPTIONAL, "row_field",
199+
new PrimitiveType(OPTIONAL, INT32, "a"),
200+
new PrimitiveType(OPTIONAL, INT32, "b")));
201+
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
202+
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames);
203+
assertTrue(calculatedTupleDomain.isAll());
204+
}
205+
131206
@Test(dataProvider = "useColumnNames")
132207
public void testParquetTupleDomainMap(boolean useColumnNames)
133208
{

0 commit comments

Comments
 (0)