Skip to content

Commit 70b8046

Browse files
committed
Add projection push down for STRUCT field in big query connector
1 parent 2ea38b5 commit 70b8046

18 files changed

+551
-34
lines changed

plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryArrowToPageConverter.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,32 @@ public void convert(PageBuilder pageBuilder, ArrowRecordBatch batch)
103103

104104
for (int column = 0; column < columns.size(); column++) {
105105
BigQueryColumnHandle columnHandle = columns.get(column);
106+
FieldVector fieldVector = getFieldVector(root, columnHandle);
106107
convertType(pageBuilder.getBlockBuilder(column),
107108
columnHandle.trinoType(),
108-
root.getVector(toBigQueryColumnName(columnHandle.name())),
109+
fieldVector,
109110
0,
110-
root.getVector(toBigQueryColumnName(columnHandle.name())).getValueCount());
111+
fieldVector.getValueCount());
111112
}
112113

113114
root.clear();
114115
}
115116

117+
private static FieldVector getFieldVector(VectorSchemaRoot root, BigQueryColumnHandle columnHandle)
118+
{
119+
FieldVector fieldVector = root.getVector(toBigQueryColumnName(columnHandle.name()));
120+
121+
for (String dereferenceName : columnHandle.dereferenceNames()) {
122+
for (FieldVector child : fieldVector.getChildrenFromFields()) {
123+
if (child.getField().getName().equals(dereferenceName)) {
124+
fieldVector = child;
125+
break;
126+
}
127+
}
128+
}
129+
return fieldVector;
130+
}
131+
116132
private void convertType(BlockBuilder output, Type type, FieldVector vector, int offset, int length)
117133
{
118134
Class<?> javaType = type.getJavaType();

plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import com.google.cloud.bigquery.TableInfo;
3434
import com.google.cloud.bigquery.TableResult;
3535
import com.google.cloud.http.BaseHttpServiceException;
36+
import com.google.common.base.Joiner;
3637
import com.google.common.cache.Cache;
3738
import com.google.common.cache.CacheLoader;
3839
import com.google.common.cache.LoadingCache;
@@ -468,8 +469,17 @@ public TableId getDestinationTable(String sql)
468469

469470
public static String selectSql(TableId table, List<BigQueryColumnHandle> requiredColumns, Optional<String> filter)
470471
{
471-
String columns = requiredColumns.stream().map(column -> format("`%s`", column.name())).collect(joining(","));
472-
return selectSql(table, columns, filter);
472+
return selectSql(table,
473+
requiredColumns.stream()
474+
.map(column -> Joiner.on('.')
475+
.join(ImmutableList.<String>builder()
476+
.add(format("`%s`", column.name()))
477+
.addAll(column.dereferenceNames().stream()
478+
.map(dereferenceName -> format("`%s`", dereferenceName))
479+
.collect(toImmutableList()))
480+
.build()))
481+
.collect(joining(",")),
482+
filter);
473483
}
474484

475485
public static String selectSql(TableId table, String formattedColumns, Optional<String> filter)

plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryColumnHandle.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import com.fasterxml.jackson.annotation.JsonIgnore;
1717
import com.google.cloud.bigquery.Field;
1818
import com.google.cloud.bigquery.StandardSQLTypeName;
19+
import com.google.common.base.Joiner;
1920
import com.google.common.collect.ImmutableList;
2021
import io.trino.spi.connector.ColumnHandle;
2122
import io.trino.spi.connector.ColumnMetadata;
@@ -30,6 +31,7 @@
3031

3132
public record BigQueryColumnHandle(
3233
String name,
34+
List<String> dereferenceNames,
3335
Type trinoType,
3436
StandardSQLTypeName bigqueryType,
3537
boolean isPushdownSupported,
@@ -44,6 +46,7 @@ public record BigQueryColumnHandle(
4446
public BigQueryColumnHandle
4547
{
4648
requireNonNull(name, "name is null");
49+
dereferenceNames = ImmutableList.copyOf(requireNonNull(dereferenceNames, "dereferenceNames is null"));
4750
requireNonNull(trinoType, "trinoType is null");
4851
requireNonNull(bigqueryType, "bigqueryType is null");
4952
requireNonNull(mode, "mode is null");
@@ -62,6 +65,16 @@ public ColumnMetadata getColumnMetadata()
6265
.build();
6366
}
6467

68+
@JsonIgnore
69+
public String getQualifiedName()
70+
{
71+
return Joiner.on('.')
72+
.join(ImmutableList.<String>builder()
73+
.add(name)
74+
.addAll(dereferenceNames)
75+
.build());
76+
}
77+
6578
@JsonIgnore
6679
public long getRetainedSizeInBytes()
6780
{

plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryConfig.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public class BigQueryConfig
6363
private String queryLabelName;
6464
private String queryLabelFormat;
6565
private boolean proxyEnabled;
66+
private boolean projectionPushDownEnabled = true;
6667
private int metadataParallelism = 2;
6768

6869
public Optional<String> getProjectId()
@@ -342,6 +343,19 @@ public BigQueryConfig setProxyEnabled(boolean proxyEnabled)
342343
return this;
343344
}
344345

346+
public boolean isProjectionPushdownEnabled()
347+
{
348+
return projectionPushDownEnabled;
349+
}
350+
351+
@Config("bigquery.projection-pushdown-enabled")
352+
@ConfigDescription("Dereference push down for ROW type")
353+
public BigQueryConfig setProjectionPushdownEnabled(boolean projectionPushDownEnabled)
354+
{
355+
this.projectionPushDownEnabled = projectionPushDownEnabled;
356+
return this;
357+
}
358+
345359
@Min(1)
346360
@Max(32)
347361
public int getMetadataParallelism()

plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java

Lines changed: 154 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@
3232
import com.google.cloud.bigquery.storage.v1.JsonStreamWriter;
3333
import com.google.cloud.bigquery.storage.v1.TableName;
3434
import com.google.cloud.bigquery.storage.v1.WriteStream;
35+
import com.google.common.annotations.VisibleForTesting;
3536
import com.google.common.base.Functions;
3637
import com.google.common.collect.ImmutableList;
3738
import com.google.common.collect.ImmutableMap;
3839
import com.google.common.collect.ImmutableSet;
40+
import com.google.common.collect.Ordering;
3941
import com.google.common.io.Closer;
4042
import com.google.common.util.concurrent.ListenableFuture;
4143
import com.google.common.util.concurrent.ListeningExecutorService;
4244
import io.airlift.log.Logger;
4345
import io.airlift.slice.Slice;
46+
import io.trino.plugin.base.projection.ApplyProjectionUtil;
4447
import io.trino.plugin.bigquery.BigQueryClient.RemoteDatabaseObject;
4548
import io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType;
4649
import io.trino.plugin.bigquery.ptf.Query.QueryHandle;
@@ -78,18 +81,22 @@
7881
import io.trino.spi.connector.TableFunctionApplicationResult;
7982
import io.trino.spi.connector.TableNotFoundException;
8083
import io.trino.spi.expression.ConnectorExpression;
84+
import io.trino.spi.expression.Variable;
8185
import io.trino.spi.function.table.ConnectorTableFunctionHandle;
8286
import io.trino.spi.predicate.Domain;
8387
import io.trino.spi.predicate.TupleDomain;
8488
import io.trino.spi.security.TrinoPrincipal;
8589
import io.trino.spi.statistics.ComputedStatistics;
8690
import io.trino.spi.type.BigintType;
91+
import io.trino.spi.type.RowType;
8792
import io.trino.spi.type.Type;
8893
import io.trino.spi.type.VarcharType;
8994
import org.json.JSONArray;
9095

9196
import java.io.IOException;
97+
import java.util.ArrayList;
9298
import java.util.Collection;
99+
import java.util.Comparator;
93100
import java.util.HashMap;
94101
import java.util.Iterator;
95102
import java.util.List;
@@ -108,16 +115,22 @@
108115
import static com.google.cloud.bigquery.storage.v1.WriteStream.Type.COMMITTED;
109116
import static com.google.common.base.Preconditions.checkArgument;
110117
import static com.google.common.base.Preconditions.checkState;
118+
import static com.google.common.base.Verify.verify;
111119
import static com.google.common.collect.ImmutableList.toImmutableList;
112120
import static com.google.common.collect.ImmutableMap.toImmutableMap;
121+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
113122
import static com.google.common.util.concurrent.Futures.allAsList;
114123
import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName;
124+
import static io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation;
125+
import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
126+
import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables;
115127
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_BAD_WRITE;
116128
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_FAILED_TO_EXECUTE_QUERY;
117129
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_LISTING_TABLE_ERROR;
118130
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_UNSUPPORTED_OPERATION;
119131
import static io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_DATE;
120132
import static io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_TIME;
133+
import static io.trino.plugin.bigquery.BigQuerySessionProperties.isProjectionPushdownEnabled;
121134
import static io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType.INGESTION;
122135
import static io.trino.plugin.bigquery.BigQueryTableHandle.getPartitionType;
123136
import static io.trino.plugin.bigquery.BigQueryUtil.isWildcardTable;
@@ -138,6 +151,8 @@ public class BigQueryMetadata
138151
{
139152
private static final Logger log = Logger.get(BigQueryMetadata.class);
140153
private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT;
154+
private static final Ordering<BigQueryColumnHandle> COLUMN_HANDLE_ORDERING = Ordering
155+
.from(Comparator.comparingInt(columnHandle -> columnHandle.dereferenceNames().size()));
141156

142157
static final int DEFAULT_NUMERIC_TYPE_PRECISION = 38;
143158
static final int DEFAULT_NUMERIC_TYPE_SCALE = 9;
@@ -771,7 +786,7 @@ public Optional<ConnectorOutputMetadata> finishInsert(
771786
@Override
772787
public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle)
773788
{
774-
return new BigQueryColumnHandle("$merge_row_id", BIGINT, INT64, true, Field.Mode.REQUIRED, ImmutableList.of(), null, true);
789+
return new BigQueryColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, INT64, true, Field.Mode.REQUIRED, ImmutableList.of(), null, true);
775790
}
776791

777792
@Override
@@ -882,24 +897,150 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
882897
log.debug("applyProjection(session=%s, handle=%s, projections=%s, assignments=%s)",
883898
session, handle, projections, assignments);
884899
BigQueryTableHandle bigQueryTableHandle = (BigQueryTableHandle) handle;
900+
if (!isProjectionPushdownEnabled(session)) {
901+
List<ColumnHandle> newColumns = ImmutableList.copyOf(assignments.values());
902+
if (bigQueryTableHandle.projectedColumns().isPresent() && containSameElements(newColumns, bigQueryTableHandle.projectedColumns().get())) {
903+
return Optional.empty();
904+
}
885905

886-
List<ColumnHandle> newColumns = ImmutableList.copyOf(assignments.values());
906+
ImmutableList.Builder<BigQueryColumnHandle> projectedColumns = ImmutableList.builder();
907+
ImmutableList.Builder<Assignment> assignmentList = ImmutableList.builder();
908+
assignments.forEach((name, column) -> {
909+
BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column;
910+
projectedColumns.add(columnHandle);
911+
assignmentList.add(new Assignment(name, column, columnHandle.trinoType()));
912+
});
887913

888-
if (bigQueryTableHandle.projectedColumns().isPresent() && containSameElements(newColumns, bigQueryTableHandle.projectedColumns().get())) {
889-
return Optional.empty();
914+
bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build());
915+
916+
return Optional.of(new ProjectionApplicationResult<>(bigQueryTableHandle, projections, assignmentList.build(), false));
890917
}
891918

892-
ImmutableList.Builder<BigQueryColumnHandle> projectedColumns = ImmutableList.builder();
893-
ImmutableList.Builder<Assignment> assignmentList = ImmutableList.builder();
894-
assignments.forEach((name, column) -> {
895-
BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column;
896-
projectedColumns.add(columnHandle);
897-
assignmentList.add(new Assignment(name, column, columnHandle.trinoType()));
898-
});
919+
// Create projected column representations for supported sub expressions. Simple column references and chain of
920+
// dereferences on a variable are supported right now.
921+
Set<ConnectorExpression> projectedExpressions = projections.stream()
922+
.flatMap(expression -> extractSupportedProjectedColumns(expression).stream())
923+
.collect(toImmutableSet());
924+
925+
Map<ConnectorExpression, ProjectedColumnRepresentation> columnProjections = projectedExpressions.stream()
926+
.collect(toImmutableMap(identity(), ApplyProjectionUtil::createProjectedColumnRepresentation));
927+
928+
// all references are simple variables
929+
if (columnProjections.values().stream().allMatch(ProjectedColumnRepresentation::isVariable)) {
930+
Set<BigQueryColumnHandle> projectedColumns = ImmutableSet.copyOf(projectParentColumns(assignments.values().stream()
931+
.map(BigQueryColumnHandle.class::cast)
932+
.collect(toImmutableList())));
933+
if (bigQueryTableHandle.projectedColumns().isPresent() && containSameElements(projectedColumns, bigQueryTableHandle.projectedColumns().get())) {
934+
return Optional.empty();
935+
}
936+
List<Assignment> assignmentsList = assignments.entrySet().stream()
937+
.map(assignment -> new Assignment(
938+
assignment.getKey(),
939+
assignment.getValue(),
940+
((BigQueryColumnHandle) assignment.getValue()).trinoType()))
941+
.collect(toImmutableList());
942+
943+
return Optional.of(new ProjectionApplicationResult<>(
944+
bigQueryTableHandle.withProjectedColumns(ImmutableList.copyOf(projectedColumns)),
945+
projections,
946+
assignmentsList,
947+
false));
948+
}
949+
950+
Map<String, Assignment> newAssignments = new HashMap<>();
951+
ImmutableMap.Builder<ConnectorExpression, Variable> newVariablesBuilder = ImmutableMap.builder();
952+
ImmutableSet.Builder<BigQueryColumnHandle> projectedColumnsBuilder = ImmutableSet.builder();
953+
954+
for (Map.Entry<ConnectorExpression, ProjectedColumnRepresentation> entry : columnProjections.entrySet()) {
955+
ConnectorExpression expression = entry.getKey();
956+
ProjectedColumnRepresentation projectedColumn = entry.getValue();
899957

900-
bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build());
958+
BigQueryColumnHandle baseColumnHandle = (BigQueryColumnHandle) assignments.get(projectedColumn.getVariable().getName());
959+
BigQueryColumnHandle projectedColumnHandle = createProjectedColumnHandle(baseColumnHandle, projectedColumn.getDereferenceIndices(), expression.getType());
960+
String projectedColumnName = projectedColumnHandle.getQualifiedName();
961+
962+
Variable projectedColumnVariable = new Variable(projectedColumnName, expression.getType());
963+
Assignment newAssignment = new Assignment(projectedColumnName, projectedColumnHandle, expression.getType());
964+
newAssignments.putIfAbsent(projectedColumnName, newAssignment);
965+
966+
newVariablesBuilder.put(expression, projectedColumnVariable);
967+
projectedColumnsBuilder.add(projectedColumnHandle);
968+
}
969+
970+
// Modify projections to refer to new variables
971+
Map<ConnectorExpression, Variable> newVariables = newVariablesBuilder.buildOrThrow();
972+
List<ConnectorExpression> newProjections = projections.stream()
973+
.map(expression -> replaceWithNewVariables(expression, newVariables))
974+
.collect(toImmutableList());
975+
976+
List<Assignment> outputAssignments = newAssignments.values().stream().collect(toImmutableList());
977+
return Optional.of(new ProjectionApplicationResult<>(
978+
bigQueryTableHandle.withProjectedColumns(projectParentColumns(ImmutableList.copyOf(projectedColumnsBuilder.build()))),
979+
newProjections,
980+
outputAssignments,
981+
false));
982+
}
901983

902-
return Optional.of(new ProjectionApplicationResult<>(bigQueryTableHandle, projections, assignmentList.build(), false));
984+
/**
985+
* Creates a set of parent columns for the input projected columns. For example,
986+
* if input {@param columns} include columns "a.b" and "a.b.c", then they will be projected from a single column "a.b".
987+
*/
988+
@VisibleForTesting
989+
static List<BigQueryColumnHandle> projectParentColumns(List<BigQueryColumnHandle> columnHandles)
990+
{
991+
List<BigQueryColumnHandle> sortedColumnHandles = COLUMN_HANDLE_ORDERING.sortedCopy(columnHandles);
992+
List<BigQueryColumnHandle> parentColumns = new ArrayList<>();
993+
for (BigQueryColumnHandle column : sortedColumnHandles) {
994+
if (!parentColumnExists(parentColumns, column)) {
995+
parentColumns.add(column);
996+
}
997+
}
998+
return parentColumns;
999+
}
1000+
1001+
private static boolean parentColumnExists(List<BigQueryColumnHandle> existingColumns, BigQueryColumnHandle column)
1002+
{
1003+
for (BigQueryColumnHandle existingColumn : existingColumns) {
1004+
List<String> existingColumnDereferenceNames = existingColumn.dereferenceNames();
1005+
verify(
1006+
column.dereferenceNames().size() >= existingColumnDereferenceNames.size(),
1007+
"Selected column's dereference size must be greater than or equal to the existing column's dereference size");
1008+
if (existingColumn.name().equals(column.name())
1009+
&& column.dereferenceNames().subList(0, existingColumnDereferenceNames.size()).equals(existingColumnDereferenceNames)) {
1010+
return true;
1011+
}
1012+
}
1013+
return false;
1014+
}
1015+
1016+
private BigQueryColumnHandle createProjectedColumnHandle(BigQueryColumnHandle baseColumn, List<Integer> indices, Type projectedColumnType)
1017+
{
1018+
if (indices.isEmpty()) {
1019+
return baseColumn;
1020+
}
1021+
1022+
ImmutableList.Builder<String> dereferenceNamesBuilder = ImmutableList.builder();
1023+
dereferenceNamesBuilder.addAll(baseColumn.dereferenceNames());
1024+
1025+
Type type = baseColumn.trinoType();
1026+
for (int index : indices) {
1027+
checkArgument(type instanceof RowType, "type should be Row type");
1028+
RowType rowType = (RowType) type;
1029+
RowType.Field field = rowType.getFields().get(index);
1030+
dereferenceNamesBuilder.add(field.getName()
1031+
.orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "ROW type does not have field names declared: " + rowType)));
1032+
type = field.getType();
1033+
}
1034+
return new BigQueryColumnHandle(
1035+
baseColumn.name(),
1036+
dereferenceNamesBuilder.build(),
1037+
projectedColumnType,
1038+
typeManager.toStandardSqlTypeName(projectedColumnType),
1039+
baseColumn.isPushdownSupported(),
1040+
baseColumn.mode(),
1041+
baseColumn.subColumns(),
1042+
baseColumn.description(),
1043+
baseColumn.hidden());
9031044
}
9041045

9051046
@Override

0 commit comments

Comments
 (0)