diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java index c15365d71daa..d177788fa71c 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java @@ -19,6 +19,7 @@ import com.google.cloud.bigquery.Dataset; import com.google.cloud.bigquery.DatasetId; import com.google.cloud.bigquery.DatasetInfo; +import com.google.cloud.bigquery.FieldValue; import com.google.cloud.bigquery.Job; import com.google.cloud.bigquery.JobException; import com.google.cloud.bigquery.JobInfo; @@ -32,6 +33,7 @@ import com.google.cloud.bigquery.TableInfo; import com.google.cloud.bigquery.TableResult; import com.google.cloud.http.BaseHttpServiceException; +import com.google.common.base.Suppliers; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableSet; @@ -40,6 +42,7 @@ import io.trino.cache.EvictableCacheBuilder; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.RelationColumnsMetadata; import io.trino.spi.connector.RelationCommentMetadata; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableNotFoundException; @@ -408,6 +411,34 @@ private static String fullTableName(TableId remoteTableId) return format("%s.%s.%s", remoteTableId.getProject(), remoteTableId.getDataset(), remoteTableId.getTable()); } + public Stream listRelationColumnsMetadata(ConnectorSession session, BigQueryClient client, String projectId, String remoteSchemaName) + { + TableResult result = client.executeQuery(session, """ + SELECT + table_catalog, + table_schema, + table_name, + array_agg(column_name order by ordinal_position), + array_agg(data_type order by ordinal_position), + FROM %s.INFORMATION_SCHEMA.COLUMNS + GROUP BY table_catalog, table_schema, table_name + """.formatted(quote(remoteSchemaName))); + String schemaName = client.toSchemaName(DatasetId.of(projectId, remoteSchemaName)); + return result.streamValues() + .map(row -> { + RemoteTableName remoteTableName = new RemoteTableName( + row.get(0).getStringValue(), + row.get(1).getStringValue(), + row.get(2).getStringValue()); + List names = row.get(3).getRepeatedValue().stream().map(FieldValue::getStringValue).collect(toImmutableList()); + List types = row.get(4).getRepeatedValue().stream().map(FieldValue::getStringValue).collect(toImmutableList()); + verify(names.size() == types.size(), "Mismatched column names and types"); + return RelationColumnsMetadata.forTable( + new SchemaTableName(schemaName, remoteTableName.tableName()), + typeManager.convertToTrinoType(names, types, Suppliers.memoize(() -> getTable(remoteTableName.toTableId())))); + }); + } + public Stream listRelationCommentMetadata(ConnectorSession session, BigQueryClient client, String schemaName) { TableResult result = client.executeQuery(session, """ diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 538b5d0306f6..8d9ef9eb95bc 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -66,6 +66,7 @@ import io.trino.spi.connector.InMemoryRecordSet; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RecordCursor; +import io.trino.spi.connector.RelationColumnsMetadata; import io.trino.spi.connector.RelationCommentMetadata; import io.trino.spi.connector.RetryMode; import io.trino.spi.connector.SaveMode; @@ -263,6 +264,50 @@ private List listTablesInRemoteSchema(BigQueryClient client, St return tableNames.build(); } + @Override + public Iterator streamRelationColumns(ConnectorSession session, Optional schemaName, UnaryOperator> relationFilter) + { + if (isLegacyMetadataListing) { + return ConnectorMetadata.super.streamRelationColumns(session, schemaName, relationFilter); + } + BigQueryClient client = bigQueryClientFactory.create(session); + String projectId; + List schemaNames; + if (schemaName.isPresent()) { + DatasetId localDatasetId = client.toDatasetId(schemaName.get()); + projectId = localDatasetId.getProject(); + String remoteSchemaName = getRemoteSchemaName(client, localDatasetId.getProject(), localDatasetId.getDataset()); + schemaNames = List.of(remoteSchemaName); + } + else { + projectId = client.getProjectId(); + schemaNames = listRemoteSchemaNames(session); + } + Map resultsByName = schemaNames.stream() + .flatMap(schema -> listRelationColumnsMetadata(session, client, schema, projectId)) + .collect(toImmutableMap(RelationColumnsMetadata::name, Functions.identity(), (first, _) -> { + log.debug("Filtered out [%s] from list of tables due to ambiguous name", first.name()); + return null; + })); + return relationFilter.apply(resultsByName.keySet()).stream() + .map(resultsByName::get) + .iterator(); + } + + private static Stream listRelationColumnsMetadata(ConnectorSession session, BigQueryClient client, String schema, String projectId) + { + try { + return client.listRelationColumnsMetadata(session, client, projectId, schema); + } + catch (BigQueryException e) { + if (e.getCode() == 404) { + log.debug("Dataset disappeared during listing operation: %s", schema); + return Stream.empty(); + } + throw new TrinoException(BIGQUERY_LISTING_TABLE_ERROR, "Failed to retrieve tables from BigQuery", e); + } + } + @Override public Iterator streamRelationComments(ConnectorSession session, Optional schemaName, UnaryOperator> relationFilter) { @@ -278,7 +323,7 @@ public Iterator streamRelationComments(ConnectorSession }).orElseGet(() -> listSchemaNames(session)); Map resultsByName = schemaNames.stream() .flatMap(schema -> listRelationCommentMetadata(session, client, schema)) - .collect(toImmutableMap(RelationCommentMetadata::name, Functions.identity(), (first, second) -> { + .collect(toImmutableMap(RelationCommentMetadata::name, Functions.identity(), (first, _) -> { log.debug("Filtered out [%s] from list of tables due to ambiguous name", first.name()); return null; })); diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeManager.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeManager.java index 28d2a3f9eb3a..55b169574606 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeManager.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeManager.java @@ -17,31 +17,28 @@ import com.google.cloud.bigquery.FieldList; import com.google.cloud.bigquery.LegacySQLTypeName; import com.google.cloud.bigquery.StandardSQLTypeName; +import com.google.cloud.bigquery.TableInfo; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.inject.Inject; import io.airlift.slice.Slice; +import io.trino.plugin.bigquery.type.ArrayTypeInfo; +import io.trino.plugin.bigquery.type.BigDecimalTypeInfo; +import io.trino.plugin.bigquery.type.DecimalTypeInfo; +import io.trino.plugin.bigquery.type.PrimitiveTypeInfo; +import io.trino.plugin.bigquery.type.TypeInfo; +import io.trino.plugin.bigquery.type.UnsupportedTypeException; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.type.ArrayType; -import io.trino.spi.type.BigintType; -import io.trino.spi.type.BooleanType; -import io.trino.spi.type.DateType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; -import io.trino.spi.type.DoubleType; import io.trino.spi.type.Int128; -import io.trino.spi.type.IntegerType; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.RowType; -import io.trino.spi.type.SmallintType; -import io.trino.spi.type.TimeType; -import io.trino.spi.type.TimestampType; -import io.trino.spi.type.TimestampWithTimeZoneType; -import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeSignature; -import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; import jakarta.annotation.Nullable; @@ -55,26 +52,42 @@ import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.stream.Collectors; +import java.util.function.Supplier; import static com.google.cloud.bigquery.Field.Mode.REPEATED; +import static com.google.cloud.bigquery.StandardSQLTypeName.STRUCT; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.bigquery.BigQueryMetadata.DEFAULT_NUMERIC_TYPE_PRECISION; import static io.trino.plugin.bigquery.BigQueryMetadata.DEFAULT_NUMERIC_TYPE_SCALE; +import static io.trino.plugin.bigquery.type.TypeInfoUtils.parseTypeString; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.Decimals.MAX_PRECISION; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.StandardTypes.JSON; -import static io.trino.spi.type.TimeWithTimeZoneType.DEFAULT_PRECISION; -import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType; +import static io.trino.spi.type.TimeType.TIME_MICROS; import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MILLISECOND; import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_SECOND; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.TypeSignature.arrayType; +import static io.trino.spi.type.TypeSignatureParameter.typeParameter; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static java.lang.Integer.parseInt; import static java.lang.Math.floorDiv; @@ -83,7 +96,6 @@ import static java.lang.String.format; import static java.time.ZoneOffset.UTC; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; public final class BigQueryTypeManager { @@ -102,12 +114,14 @@ public final class BigQueryTypeManager private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormatter.ofPattern("''HH:mm:ss.SSSSSS''"); private static final DateTimeFormatter DATETIME_FORMATTER = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSSSS").withZone(UTC); + private final TypeManager typeManager; private final Type jsonType; @Inject public BigQueryTypeManager(TypeManager typeManager) { - jsonType = requireNonNull(typeManager, "typeManager is null").getType(new TypeSignature(JSON)); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + jsonType = typeManager.getType(new TypeSignature(JSON)); } private RowType.Field toRawTypeField(String name, Field field) @@ -226,7 +240,7 @@ private Field toInnerField(String name, Type type, boolean repeated, @Nullable S { Field.Builder builder; if (type instanceof RowType) { - builder = Field.newBuilder(name, StandardSQLTypeName.STRUCT, toFieldList((RowType) type)).setDescription(comment); + builder = Field.newBuilder(name, STRUCT, toFieldList((RowType) type)).setDescription(comment); } else { builder = Field.newBuilder(name, toStandardSqlTypeName(type)).setDescription(comment); @@ -250,41 +264,41 @@ private FieldList toFieldList(RowType rowType) private StandardSQLTypeName toStandardSqlTypeName(Type type) { - if (type == BooleanType.BOOLEAN) { + if (type == BOOLEAN) { return StandardSQLTypeName.BOOL; } - if (type == TinyintType.TINYINT || type == SmallintType.SMALLINT || type == IntegerType.INTEGER || type == BigintType.BIGINT) { + if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT) { return StandardSQLTypeName.INT64; } - if (type == DoubleType.DOUBLE) { + if (type == DOUBLE) { return StandardSQLTypeName.FLOAT64; } if (type instanceof DecimalType) { return StandardSQLTypeName.NUMERIC; } - if (type == DateType.DATE) { + if (type == DATE) { return StandardSQLTypeName.DATE; } - if (type == createTimeWithTimeZoneType(DEFAULT_PRECISION)) { + if (type == TIME_MICROS) { return StandardSQLTypeName.TIME; } - if (type == TimestampType.TIMESTAMP_MICROS) { + if (type == TIMESTAMP_MICROS) { return StandardSQLTypeName.DATETIME; } - if (type == TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS) { + if (type == TIMESTAMP_TZ_MICROS) { return StandardSQLTypeName.TIMESTAMP; } if (type instanceof VarcharType) { return StandardSQLTypeName.STRING; } - if (type == VarbinaryType.VARBINARY) { + if (type == VARBINARY) { return StandardSQLTypeName.BYTES; } if (type instanceof ArrayType) { return StandardSQLTypeName.ARRAY; } if (type instanceof RowType) { - return StandardSQLTypeName.STRUCT; + return STRUCT; } throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } @@ -335,11 +349,11 @@ private Optional convertToTrinoType(Field field) { switch (field.getType().getStandardType()) { case BOOL: - return Optional.of(new ColumnMapping(BooleanType.BOOLEAN, true)); + return Optional.of(new ColumnMapping(BOOLEAN, true)); case INT64: - return Optional.of(new ColumnMapping(BigintType.BIGINT, true)); + return Optional.of(new ColumnMapping(BIGINT, true)); case FLOAT64: - return Optional.of(new ColumnMapping(DoubleType.DOUBLE, true)); + return Optional.of(new ColumnMapping(DOUBLE, true)); case NUMERIC: case BIGNUMERIC: Long precision = field.getPrecision(); @@ -355,31 +369,98 @@ private Optional convertToTrinoType(Field field) case STRING: return Optional.of(new ColumnMapping(createUnboundedVarcharType(), true)); case BYTES: - return Optional.of(new ColumnMapping(VarbinaryType.VARBINARY, true)); + return Optional.of(new ColumnMapping(VARBINARY, true)); case DATE: - return Optional.of(new ColumnMapping(DateType.DATE, true)); + return Optional.of(new ColumnMapping(DATE, true)); case DATETIME: - return Optional.of(new ColumnMapping(TimestampType.TIMESTAMP_MICROS, true)); + return Optional.of(new ColumnMapping(TIMESTAMP_MICROS, true)); case TIME: - return Optional.of(new ColumnMapping(TimeType.TIME_MICROS, true)); + return Optional.of(new ColumnMapping(TIME_MICROS, true)); case TIMESTAMP: - return Optional.of(new ColumnMapping(TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS, true)); + return Optional.of(new ColumnMapping(TIMESTAMP_TZ_MICROS, true)); case GEOGRAPHY: - return Optional.of(new ColumnMapping(VarcharType.VARCHAR, false)); + return Optional.of(new ColumnMapping(VARCHAR, false)); case JSON: return Optional.of(new ColumnMapping(jsonType, false)); case STRUCT: - // create the row - FieldList subTypes = field.getSubFields(); - checkArgument(!subTypes.isEmpty(), "a record or struct must have sub-fields"); - List fields = subTypes.stream().map(subField -> toRawTypeField(subField.getName(), subField)).collect(toList()); - RowType rowType = RowType.from(fields); - return Optional.of(new ColumnMapping(rowType, false)); + return Optional.of(new ColumnMapping(createRowType(field), false)); default: return Optional.empty(); } } + private RowType createRowType(Field field) + { + FieldList subTypes = field.getSubFields(); + checkArgument(!subTypes.isEmpty(), "a record or struct must have sub-fields"); + List fields = subTypes.stream() + .map(subField -> toRawTypeField(subField.getName(), subField)) + .collect(toImmutableList()); + return RowType.from(fields); + } + + public List convertToTrinoType(List names, List types) + { + return convertToTrinoType(names, types, Optional::empty); + } + + public List convertToTrinoType(List names, List types, Supplier> tableSupplier) + { + checkArgument(names.size() == types.size(), "Mismatched column names and types"); + + ImmutableList.Builder columns = ImmutableList.builder(); + for (int i = 0; i < names.size(); i++) { + String name = names.get(i); + TypeSignature typeSignature; + try { + TypeInfo typeInfo = parseTypeString(types.get(i)); + typeSignature = toTypeSignature(typeInfo); + } + catch (UnsupportedTypeException e) { + Optional table = tableSupplier.get(); + if (!e.getTypeName().equals(STRUCT) || table.isEmpty() || table.get().getDefinition().getSchema() == null) { + // ignore unsupported types + continue; + } + typeSignature = createRowType(table.get().getDefinition().getSchema().getFields().get(name)).getTypeSignature(); + } + catch (TrinoException | IllegalArgumentException e) { + // ignore unsupported types + continue; + } + columns.add(new ColumnMetadata(name, typeManager.getType(typeSignature))); + } + return columns.build(); + } + + private TypeSignature toTypeSignature(TypeInfo typeInfo) + { + return switch (typeInfo) { + case DecimalTypeInfo decimalTypeInfo: + yield createDecimalType(decimalTypeInfo.precision(), decimalTypeInfo.scale()).getTypeSignature(); + case BigDecimalTypeInfo decimalTypeInfo: + yield createDecimalType(decimalTypeInfo.precision(), decimalTypeInfo.scale()).getTypeSignature(); + case PrimitiveTypeInfo primitiveTypeInfo: + Type type = switch (primitiveTypeInfo.getStandardSqlTypeName()) { + case BOOL -> BOOLEAN; + case INT64 -> BIGINT; + case FLOAT64 -> DOUBLE; + case STRING -> VARCHAR; + case BYTES -> VARBINARY; + case DATE -> DATE; + case DATETIME -> TIMESTAMP_MICROS; + case TIMESTAMP -> TIMESTAMP_TZ_MICROS; + case GEOGRAPHY -> VARCHAR; + case JSON -> jsonType; + default -> throw new IllegalArgumentException("Unsupported type: " + primitiveTypeInfo); + }; + yield type.getTypeSignature(); + case ArrayTypeInfo arrayTypeInfo: + TypeSignature elementType = toTypeSignature(arrayTypeInfo.getListElementTypeInfo()); + yield arrayType(typeParameter(elementType)); + }; + } + public BigQueryColumnHandle toColumnHandle(Field field) { FieldList subFields = field.getSubFields(); @@ -388,7 +469,7 @@ public BigQueryColumnHandle toColumnHandle(Field field) subFields.stream() .filter(this::isSupportedType) .map(this::toColumnHandle) - .collect(Collectors.toList()); + .collect(toImmutableList()); ColumnMapping columnMapping = toTrinoType(field).orElseThrow(() -> new IllegalArgumentException("Unsupported type: " + field)); return new BigQueryColumnHandle( field.getName(), @@ -409,7 +490,7 @@ public boolean isSupportedType(Field field) if (field.getPrecision() == null && field.getScale() == null) { return false; } - if (field.getPrecision() != null && field.getPrecision() > Decimals.MAX_PRECISION) { + if (field.getPrecision() != null && field.getPrecision() > MAX_PRECISION) { return false; } } diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/ArrayTypeInfo.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/ArrayTypeInfo.java new file mode 100644 index 000000000000..66b3e759e337 --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/ArrayTypeInfo.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import static com.google.cloud.bigquery.StandardSQLTypeName.ARRAY; +import static java.util.Objects.requireNonNull; + +public final class ArrayTypeInfo + extends TypeInfo +{ + private final TypeInfo elementTypeInfo; + + ArrayTypeInfo(TypeInfo elementTypeInfo) + { + this.elementTypeInfo = requireNonNull(elementTypeInfo, "elementTypeInfo is null"); + } + + @Override + public String toString() + { + return ARRAY + "<" + elementTypeInfo + ">"; + } + + public TypeInfo getListElementTypeInfo() + { + return elementTypeInfo; + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/BigDecimalTypeInfo.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/BigDecimalTypeInfo.java new file mode 100644 index 000000000000..e0ed95be3e3c --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/BigDecimalTypeInfo.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import static com.google.cloud.bigquery.StandardSQLTypeName.BIGNUMERIC; +import static com.google.common.base.Preconditions.checkArgument; + +public final class BigDecimalTypeInfo + extends PrimitiveTypeInfo +{ + private static final int MAX_PRECISION_MINUS_SCALE = 38; + private static final int MAX_SCALE = 38; + + private final int precision; + private final int scale; + + public BigDecimalTypeInfo(int precision, int scale) + { + super(BIGNUMERIC.name()); + this.precision = precision; + this.scale = scale; + checkArgument(scale >= 0 && scale <= MAX_SCALE, "invalid decimal scale: %s", scale); + checkArgument(precision >= 1 && precision <= MAX_PRECISION_MINUS_SCALE + scale, "invalid decimal precision: %s", precision); + checkArgument(scale <= precision, "invalid decimal precision: %s is lower than scale %s", precision, scale); + } + + @Override + public String toString() + { + return decimalTypeName(precision, scale); + } + + public int precision() + { + return precision; + } + + public int scale() + { + return scale; + } + + public static String decimalTypeName(int precision, int scale) + { + return BIGNUMERIC.name() + "(" + precision + ", " + scale + ")"; + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/DecimalTypeInfo.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/DecimalTypeInfo.java new file mode 100644 index 000000000000..d3385c4e6fb7 --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/DecimalTypeInfo.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import static com.google.cloud.bigquery.StandardSQLTypeName.NUMERIC; +import static com.google.common.base.Preconditions.checkArgument; + +public final class DecimalTypeInfo + extends PrimitiveTypeInfo +{ + private static final int MAX_PRECISION_MINUS_SCALE = 29; + private static final int MAX_SCALE = 9; + + private final int precision; + private final int scale; + + public DecimalTypeInfo(int precision, int scale) + { + super(NUMERIC.name()); + this.precision = precision; + this.scale = scale; + checkArgument(scale >= 0 && scale <= MAX_SCALE, "invalid decimal scale: %s", scale); + checkArgument(precision >= 1 && precision <= MAX_PRECISION_MINUS_SCALE + scale, "invalid decimal precision: %s", precision); + checkArgument(scale <= precision, "invalid decimal precision: %s is lower than scale %s", precision, scale); + } + + @Override + public String toString() + { + return decimalTypeName(precision, scale); + } + + public int precision() + { + return precision; + } + + public int scale() + { + return scale; + } + + public static String decimalTypeName(int precision, int scale) + { + return NUMERIC.name() + "(" + precision + ", " + scale + ")"; + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/PrimitiveTypeInfo.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/PrimitiveTypeInfo.java new file mode 100644 index 000000000000..c329549d9f48 --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/PrimitiveTypeInfo.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import com.google.cloud.bigquery.StandardSQLTypeName; + +import static io.trino.plugin.bigquery.type.TypeInfoUtils.getStandardSqlTypeNameFromTypeName; +import static java.util.Objects.requireNonNull; + +public sealed class PrimitiveTypeInfo + extends TypeInfo + permits BigDecimalTypeInfo, DecimalTypeInfo +{ + private final String typeName; + private final StandardSQLTypeName standardSqlTypeName; + + PrimitiveTypeInfo(String typeName) + { + this.typeName = requireNonNull(typeName, "typeName is null"); + this.standardSqlTypeName = getStandardSqlTypeNameFromTypeName(typeName); + } + + public StandardSQLTypeName getStandardSqlTypeName() + { + return standardSqlTypeName; + } + + @Override + public String toString() + { + return typeName; + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfo.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfo.java new file mode 100644 index 000000000000..91e53dc16ad0 --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfo.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +public abstract sealed class TypeInfo + permits PrimitiveTypeInfo, ArrayTypeInfo +{ + protected TypeInfo() {} + + @Override + public abstract String toString(); +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfoFactory.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfoFactory.java new file mode 100644 index 000000000000..85d26cdd775c --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfoFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import com.google.cloud.bigquery.StandardSQLTypeName; + +public final class TypeInfoFactory +{ + private TypeInfoFactory() {} + + public static PrimitiveTypeInfo getPrimitiveTypeInfo(StandardSQLTypeName typeEntry) + { + return new PrimitiveTypeInfo(typeEntry.name()); + } + + public static DecimalTypeInfo getDecimalTypeInfo(int precision, int scale) + { + return new DecimalTypeInfo(precision, scale); + } + + public static BigDecimalTypeInfo getBigDecimalTypeInfo(int precision, int scale) + { + return new BigDecimalTypeInfo(precision, scale); + } + + public static TypeInfo getArrayTypeInfo(TypeInfo elementTypeInfo) + { + return new ArrayTypeInfo(elementTypeInfo); + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfoUtils.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfoUtils.java new file mode 100644 index 000000000000..b02e846ec04a --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/TypeInfoUtils.java @@ -0,0 +1,254 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import com.google.cloud.bigquery.StandardSQLTypeName; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import static com.google.cloud.bigquery.StandardSQLTypeName.ARRAY; +import static com.google.cloud.bigquery.StandardSQLTypeName.STRUCT; +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.plugin.bigquery.type.TypeInfoFactory.getArrayTypeInfo; +import static io.trino.plugin.bigquery.type.TypeInfoFactory.getBigDecimalTypeInfo; +import static io.trino.plugin.bigquery.type.TypeInfoFactory.getDecimalTypeInfo; +import static io.trino.plugin.bigquery.type.TypeInfoFactory.getPrimitiveTypeInfo; +import static java.lang.Character.isLetterOrDigit; +import static java.lang.Integer.parseInt; +import static java.util.Objects.requireNonNull; + +public final class TypeInfoUtils +{ + private static final Map TYPES = new HashMap<>(); + + private TypeInfoUtils() {} + + static { + registerType(StandardSQLTypeName.BOOL); + registerType(StandardSQLTypeName.INT64); + registerType(StandardSQLTypeName.FLOAT64); + registerType(StandardSQLTypeName.NUMERIC); + registerType(StandardSQLTypeName.BIGNUMERIC); + registerType(StandardSQLTypeName.STRING); + registerType(StandardSQLTypeName.BYTES); + registerType(StandardSQLTypeName.DATE); + registerType(StandardSQLTypeName.DATETIME); + registerType(StandardSQLTypeName.TIME); + registerType(StandardSQLTypeName.TIMESTAMP); + registerType(StandardSQLTypeName.GEOGRAPHY); + registerType(StandardSQLTypeName.JSON); + } + + private static void registerType(StandardSQLTypeName entry) + { + TYPES.put(entry.name(), entry); + } + + public static StandardSQLTypeName getStandardSqlTypeNameFromTypeName(String typeName) + { + return TYPES.get(typeName); + } + + private static class TypeInfoParser + { + public record Token(int position, String text, boolean type) + { + public Token + { + requireNonNull(text, "text is null"); + } + + @Override + public String toString() + { + return "%s:%s".formatted(position, text); + } + } + + private static boolean isTypeChar(char c) + { + return isLetterOrDigit(c) || c == '_' || c == '-'; + } + + private static boolean isDigit(String string) + { + return string.chars().allMatch(Character::isDigit); + } + + private static List tokenize(String typeInfoString) + { + List tokens = new ArrayList<>(); + int begin = 0; + int end = 1; + while (end <= typeInfoString.length()) { + if (end == typeInfoString.length() || + !isTypeChar(typeInfoString.charAt(end - 1)) || + !isTypeChar(typeInfoString.charAt(end))) { + Token token = new Token( + begin, + typeInfoString.substring(begin, end), + isTypeChar(typeInfoString.charAt(begin))); + tokens.add(token); + begin = end; + } + end++; + } + return tokens; + } + + public TypeInfoParser(String typeInfoString) + { + this.typeInfoString = typeInfoString; + typeInfoTokens = tokenize(typeInfoString); + } + + private final String typeInfoString; + private final List typeInfoTokens; + private int index; + + public TypeInfo parseTypeInfo() + { + TypeInfo typeInfo; + index = 0; + typeInfo = parseType(); + if (index < typeInfoTokens.size()) { + throw new IllegalArgumentException("Error: unexpected character at the end of '%s'".formatted(typeInfoString)); + } + return typeInfo; + } + + private Token peek() + { + if (index < typeInfoTokens.size()) { + return typeInfoTokens.get(index); + } + return null; + } + + private Token expect(String item) + { + if (index >= typeInfoTokens.size()) { + throw new IllegalArgumentException("Error: %s expected at the end of '%s'".formatted(item, typeInfoString)); + } + Token token = typeInfoTokens.get(index); + + if (item.equals("type")) { + if (!ARRAY.name().equals(token.text()) && !STRUCT.name().equals(token.text()) && getStandardSqlTypeNameFromTypeName(token.text()) == null) { + throw new IllegalArgumentException("Error: '%s' expected at the position %s of '%s' but '%s' is found.".formatted(item, token.position(), typeInfoString, token.text())); + } + } + else if (item.equals("name")) { + if (!token.type()) { + throw new IllegalArgumentException("Error: '%s' expected at the position %s of '%s' but '%s' is found.".formatted(item, token.position(), typeInfoString, token.text())); + } + } + else if (!item.equals(token.text())) { + throw new IllegalArgumentException("Error: '%s' expected at the position %s of '%s' but '%s' is found.".formatted(item, token.position(), typeInfoString, token.text())); + } + + index++; + return token; + } + + private String[] parseParams() + { + List params = new LinkedList<>(); + + Token token = peek(); + if (token != null && token.text().equals("(")) { + expect("("); + token = peek(); + while ((token == null || !token.text().equals(")")) && index < typeInfoTokens.size()) { + Token name = typeInfoTokens.get(index); + if (isDigit(name.text())) { + params.add(name.text()); + } + token = name; + index++; + } + if (params.isEmpty()) { + throw new IllegalArgumentException("type parameters expected for type string " + typeInfoString); + } + } + + return params.toArray(new String[0]); + } + + private TypeInfo parseType() + { + Token token = expect("type"); + + StandardSQLTypeName typeEntry = getStandardSqlTypeNameFromTypeName(token.text()); + if (typeEntry != null) { + String[] params = parseParams(); + return switch (typeEntry) { + case STRING -> { + if (params.length != 0 && params.length != 1) { + throw new IllegalArgumentException("Type string only takes zero or one parameter, but %s is seen".formatted(params.length)); + } + checkArgument(params.length == 0 || parseInt(params[0]) >= 0, "invalid string length, must be equal or greater than zero"); + yield getPrimitiveTypeInfo(StandardSQLTypeName.STRING); + } + case NUMERIC -> { + if (params.length == 0) { + yield getDecimalTypeInfo(38, 9); + } + if (params.length == 1) { + yield getDecimalTypeInfo(parseInt(params[0]), 0); + } + if (params.length == 2) { + yield getDecimalTypeInfo(parseInt(params[0]), parseInt(params[1])); + } + throw new IllegalArgumentException("Type decimal only takes two parameters, but %s is seen".formatted(params.length)); + } + case BIGNUMERIC -> { + if (params.length == 0) { + yield getBigDecimalTypeInfo(76, 38); + } + if (params.length == 1) { + yield getBigDecimalTypeInfo(parseInt(params[0]), 0); + } + if (params.length == 2) { + yield getBigDecimalTypeInfo(parseInt(params[0]), parseInt(params[1])); + } + throw new IllegalArgumentException("Type decimal only takes two parameters, but %s is seen".formatted(params.length)); + } + default -> getPrimitiveTypeInfo(typeEntry); + }; + } + + if (ARRAY.name().equals(token.text())) { + expect("<"); + TypeInfo listElementType = parseType(); + expect(">"); + return getArrayTypeInfo(listElementType); + } + + if (STRUCT.name().equals(token.text())) { + throw new UnsupportedTypeException(STRUCT, "STRUCT type is not supported, because it can contain unquoted field names, containing spaces, type names, and characters like '>'."); + } + + throw new RuntimeException("Internal error parsing position %s of '%s'".formatted(token.position(), typeInfoString)); + } + } + + public static TypeInfo parseTypeString(String typeString) + { + return new TypeInfoParser(typeString).parseTypeInfo(); + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/UnsupportedTypeException.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/UnsupportedTypeException.java new file mode 100644 index 000000000000..31b30b8f90aa --- /dev/null +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/type/UnsupportedTypeException.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import com.google.cloud.bigquery.StandardSQLTypeName; + +/** + * Unsupported type exception occurs only for data types + * known to be not supported, like STRUCT, not all unknown tokens. + */ +public class UnsupportedTypeException + extends IllegalArgumentException +{ + private final StandardSQLTypeName typeName; + + public UnsupportedTypeException(StandardSQLTypeName typeName, String errorMessage) + { + super(errorMessage); + this.typeName = typeName; + } + + public StandardSQLTypeName getTypeName() + { + return typeName; + } +} diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java index d2dd828f1999..8b926b26371b 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java @@ -52,6 +52,7 @@ import static io.trino.testing.assertions.Assert.assertEventually; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.stream.Collectors.joining; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assumptions.abort; @@ -80,18 +81,18 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) return switch (connectorBehavior) { case SUPPORTS_TRUNCATE -> true; case SUPPORTS_ADD_COLUMN, - SUPPORTS_CREATE_MATERIALIZED_VIEW, - SUPPORTS_CREATE_VIEW, - SUPPORTS_DEREFERENCE_PUSHDOWN, - SUPPORTS_MERGE, - SUPPORTS_NEGATIVE_DATE, - SUPPORTS_NOT_NULL_CONSTRAINT, - SUPPORTS_RENAME_COLUMN, - SUPPORTS_RENAME_SCHEMA, - SUPPORTS_RENAME_TABLE, - SUPPORTS_SET_COLUMN_TYPE, - SUPPORTS_TOPN_PUSHDOWN, - SUPPORTS_UPDATE -> false; + SUPPORTS_CREATE_MATERIALIZED_VIEW, + SUPPORTS_CREATE_VIEW, + SUPPORTS_DEREFERENCE_PUSHDOWN, + SUPPORTS_MERGE, + SUPPORTS_NEGATIVE_DATE, + SUPPORTS_NOT_NULL_CONSTRAINT, + SUPPORTS_RENAME_COLUMN, + SUPPORTS_RENAME_SCHEMA, + SUPPORTS_RENAME_TABLE, + SUPPORTS_SET_COLUMN_TYPE, + SUPPORTS_TOPN_PUSHDOWN, + SUPPORTS_UPDATE -> false; default -> super.hasBehavior(connectorBehavior); }; } @@ -368,6 +369,56 @@ public void testStreamCommentTableSpecialCharacter() } } + @Test + public void testStreamRelationColumns() + { + String schemaName = "test_columns" + randomNameSuffix(); + assertUpdate("CREATE SCHEMA " + schemaName); + try { + // Can't use testColumnNameDataProvider() here as it includes unsupported column names + List columnNames = List.of( + "_a", + "a-b", + "c_", + "c d", + "name-列", + "start&%=+:'<>#|end", + "type inside array name"); + String trinoColumns = columnNames.stream() + .map(column -> "\"" + column.replace("\"", "\"\"") + "\" varchar") + .collect(joining(", ")); + String bigQueryColumns = columnNames.stream() + .map(column -> "`" + column.replace("`", "\\`") + "` string") + .collect(joining(", ")); + assertUpdate("CREATE TABLE " + schemaName + ".trino_columns (" + trinoColumns + ", complex row(" + trinoColumns + "))"); + bigQuerySqlExecutor.execute("CREATE TABLE " + schemaName + ".bigquery_columns (" + bigQueryColumns + ", complex struct<" + bigQueryColumns + ">)"); + + // notice no predicate for table name to make sure BigQueryMetadata.streamRelationColumns() is used + assertQuery( + "SELECT table_name, column_name, data_type FROM information_schema.columns WHERE table_schema = '" + schemaName + "'", + "VALUES " + + "('trino_columns', '_a', 'varchar')," + + "('trino_columns', 'a-b', 'varchar')," + + "('trino_columns', 'c_', 'varchar')," + + "('trino_columns', 'c d', 'varchar')," + + "('trino_columns', 'name-列', 'varchar')," + + "('trino_columns', 'start&%=+:''<>#|end', 'varchar')," + + "('trino_columns', 'type inside array name', 'varchar')," + + "('trino_columns', 'complex', 'row(_a varchar, a-b varchar, c_ varchar, c d varchar, name-列 varchar, start&%=+:''<>#|end varchar, type inside array name varchar)')," + + "('bigquery_columns', '_a', 'varchar')," + + "('bigquery_columns', 'a-b', 'varchar')," + + "('bigquery_columns', 'c_', 'varchar')," + + "('bigquery_columns', 'c d', 'varchar')," + + "('bigquery_columns', 'name-列', 'varchar')," + + "('bigquery_columns', 'start&%=+:''<>#|end', 'varchar')," + + "('bigquery_columns', 'type inside array name', 'varchar')," + + "('bigquery_columns', 'complex', 'row(_a varchar, a-b varchar, c_ varchar, c d varchar, name-列 varchar, start&%=+:''<>#|end varchar, type inside array name varchar)')"); + } + finally { + assertUpdate("DROP SCHEMA " + schemaName + " CASCADE"); + } + } + @Test @Override // Override because the base test exceeds rate limits per a table public void testCommentColumn() @@ -646,6 +697,12 @@ public void testSkipUnsupportedType() " a bigint,\n" + " b bigint\n" + ")"); + // querying without predicates can use different metadata methods + String tableName = table.getName().split("\\.", 2)[1]; + assertThat(computeActual("SELECT table_name, column_name FROM information_schema.columns WHERE table_schema = 'test'").getMaterializedRows().stream() + .filter(row -> row.getField(0).equals(tableName)) + .map(row -> row.getField(1))) + .containsExactlyInAnyOrder("a", "b"); } } @@ -786,9 +843,9 @@ private void assertLabelForTable(String expectedView, QueryId queryId, String tr @Language("SQL") String checkForLabelQuery = """ - SELECT * FROM region-us.INFORMATION_SCHEMA.JOBS_BY_USER WHERE EXISTS( - SELECT * FROM UNNEST(labels) AS label WHERE label.key = 'trino_query' AND label.value = '%s' - )""".formatted(expectedLabel); + SELECT * FROM region-us.INFORMATION_SCHEMA.JOBS_BY_USER WHERE EXISTS( + SELECT * FROM UNNEST(labels) AS label WHERE label.key = 'trino_query' AND label.value = '%s' + )""".formatted(expectedLabel); assertEventually(() -> assertThat(bigQuerySqlExecutor.executeQuery(checkForLabelQuery).getValues()) .extracting(values -> values.get("query").getStringValue()) diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryType.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryType.java index 5ccacfd01a57..d604488fd9b5 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryType.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/TestBigQueryType.java @@ -13,14 +13,41 @@ */ package io.trino.plugin.bigquery; +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.RowType; import io.trino.spi.type.TimeZoneKey; +import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; +import java.util.List; +import java.util.Optional; + +import static com.google.cloud.bigquery.StandardSQLTypeName.INT64; +import static com.google.cloud.bigquery.StandardSQLTypeName.STRING; +import static com.google.cloud.bigquery.StandardSQLTypeName.STRUCT; +import static com.google.common.collect.MoreCollectors.onlyElement; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.Decimals.encodeScaledValue; +import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochSecondsAndFraction; import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static io.trino.type.JsonType.JSON; import static java.math.BigDecimal.ONE; import static org.assertj.core.api.Assertions.assertThat; @@ -30,6 +57,8 @@ @Deprecated public class TestBigQueryType { + private static final BigQueryTypeManager TYPE_MANAGER = new BigQueryTypeManager(TESTING_TYPE_MANAGER); + @Test public void testTimeToStringConverter() { @@ -84,4 +113,75 @@ public void testBytesToStringConverter() wrappedBuffer((byte) 1, (byte) 2, (byte) 3, (byte) 4))) .isEqualTo("FROM_BASE64('AQIDBA==')"); } + + @Test + void testConvertToTrinoType() + { + assertColumnType("BOOL", BOOLEAN); + assertColumnType("INT64", BIGINT); + assertColumnType("FLOAT64", DOUBLE); + assertColumnType("NUMERIC", createDecimalType(38, 9)); + assertColumnType("NUMERIC(1)", createDecimalType(1, 0)); + assertColumnType("NUMERIC(10, 5)", createDecimalType(10, 5)); + assertColumnType("NUMERIC(38, 9)", createDecimalType(38, 9)); + assertColumnType("BIGNUMERIC(1)", createDecimalType(1, 0)); + assertColumnType("BIGNUMERIC(10, 5)", createDecimalType(10, 5)); + assertColumnType("BIGNUMERIC(38, 38)", createDecimalType(38, 38)); + assertColumnType("STRING", VARCHAR); + assertColumnType("STRING(10)", VARCHAR); + assertColumnType("BYTES", VARBINARY); + assertColumnType("DATE", DATE); + assertColumnType("DATETIME", TIMESTAMP_MICROS); + assertColumnType("TIMESTAMP", TIMESTAMP_TZ_MICROS); + assertColumnType("GEOGRAPHY", VARCHAR); + assertColumnType("JSON", JSON); + assertColumnType("ARRAY", new ArrayType(BIGINT)); + } + + @Test + void testConvertToTrinoTypeStruct() + { + String structType = "STRUCT>"; + Field structField = Field.of( + "col", + STRUCT, + Field.of("x", INT64), + Field.newBuilder("y", STRING).setMode(Field.Mode.REPEATED).build()); + TableInfo tableInfo = TableInfo.of(TableId.of("fake", "table"), StandardTableDefinition.of(Schema.of(structField))); + + ColumnMetadata column = TYPE_MANAGER.convertToTrinoType( + List.of("col"), + List.of(structType), + () -> Optional.of(tableInfo)).stream() + .collect(onlyElement()); + RowType expected = RowType.from(List.of( + RowType.field("x", BIGINT), + RowType.field("y", new ArrayType(VARCHAR)))); + + // structs are not parsed, but fetched from the table info + assertThat(column.getType()).isEqualTo(expected); + // struct without table info is not recognized + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of(structType))).isEmpty(); + } + + @Test + void testConvertToTrinoTypeUnsupported() + { + // unsupported types are ignored, this includes decimals with precision and/or scale out of range + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("NUMERIC(38, 38)"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("BIGNUMERIC"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("BIGNUMERIC(76, 38)"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("TIME"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("RANGE"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("RANGE"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("RANGE"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("INTERVAL"))).isEmpty(); + assertThat(TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of("invalid-type"))).isEmpty(); + } + + private static void assertColumnType(String typeString, Type expected) + { + ColumnMetadata column = TYPE_MANAGER.convertToTrinoType(List.of("col"), List.of(typeString)).stream().collect(onlyElement()); + assertThat(column.getType()).isEqualTo(expected); + } } diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/type/TestTypeInfoUtils.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/type/TestTypeInfoUtils.java new file mode 100644 index 000000000000..b87e5cd62507 --- /dev/null +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/type/TestTypeInfoUtils.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.bigquery.type; + +import org.junit.jupiter.api.Test; + +import static io.trino.plugin.bigquery.type.TypeInfoUtils.parseTypeString; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestTypeInfoUtils +{ + @Test + public void testBasicPrimitive() + { + assertTypeInfo("BOOL"); + assertTypeInfo("INT64"); + assertTypeInfo("FLOAT64"); + assertTypeInfo("STRING"); + assertTypeInfo("STRING(10)", "STRING"); + assertTypeInfo("BYTES"); + assertTypeInfo("DATE"); + assertTypeInfo("TIME"); + assertTypeInfo("DATETIME"); + assertTypeInfo("TIMESTAMP"); + assertTypeInfo("GEOGRAPHY"); + assertTypeInfo("JSON"); + } + + @Test + public void testNumeric() + { + assertTypeInfo("NUMERIC", "NUMERIC(38, 9)"); + assertTypeInfo("NUMERIC(1)", "NUMERIC(1, 0)"); + assertTypeInfo("NUMERIC(5)", "NUMERIC(5, 0)"); + assertTypeInfo("NUMERIC(29)", "NUMERIC(29, 0)"); + assertTypeInfo("NUMERIC(1, 1)", "NUMERIC(1, 1)"); + assertTypeInfo("NUMERIC(10, 5)", "NUMERIC(10, 5)"); + assertTypeInfo("NUMERIC(38, 9)", "NUMERIC(38, 9)"); + + assertThatThrownBy(() -> parseTypeString("NUMERIC(0)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 0"); + + assertThatThrownBy(() -> parseTypeString("NUMERIC(39)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 39"); + + assertThatThrownBy(() -> parseTypeString("NUMERIC(38,39)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal scale: 39"); + + assertThatThrownBy(() -> parseTypeString("NUMERIC(4,5)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 4 is lower than scale 5"); + } + + @Test + public void testBignumeric() + { + assertTypeInfo("BIGNUMERIC", "BIGNUMERIC(76, 38)"); + assertTypeInfo("BIGNUMERIC(1)", "BIGNUMERIC(1, 0)"); + assertTypeInfo("BIGNUMERIC(5)", "BIGNUMERIC(5, 0)"); + assertTypeInfo("BIGNUMERIC(38)", "BIGNUMERIC(38, 0)"); + assertTypeInfo("BIGNUMERIC(1, 1)", "BIGNUMERIC(1, 1)"); + assertTypeInfo("BIGNUMERIC(10, 5)", "BIGNUMERIC(10, 5)"); + assertTypeInfo("BIGNUMERIC(76, 38)", "BIGNUMERIC(76, 38)"); + + assertThatThrownBy(() -> parseTypeString("BIGNUMERIC(0)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 0"); + + assertThatThrownBy(() -> parseTypeString("BIGNUMERIC(77)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 77"); + + assertThatThrownBy(() -> parseTypeString("BIGNUMERIC(76,39)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal scale: 39"); + + assertThatThrownBy(() -> parseTypeString("BIGNUMERIC(4,5)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 4 is lower than scale 5"); + + assertThatThrownBy(() -> parseTypeString("BIGNUMERIC(77)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("invalid decimal precision: 77"); + } + + @Test + public void testArray() + { + assertTypeInfo("ARRAY"); + } + + @Test + public void testInvalidTypes() + { + // incomplete types should not be recognized as different types + assertThatThrownBy(() -> parseTypeString("BOOLEAN")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of 'BOOLEAN' but 'BOOLEAN' is found."); + assertThatThrownBy(() -> parseTypeString("INT")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of 'INT' but 'INT' is found."); + assertThatThrownBy(() -> parseTypeString("TIMESTAMP WITH TIME ZONE")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: unexpected character at the end of 'TIMESTAMP WITH TIME ZONE'"); + assertThatThrownBy(() -> parseTypeString("ARRAY")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: < expected at the end of 'ARRAY'"); + assertThatThrownBy(() -> parseTypeString("STRUCT")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("STRUCT type is not supported, because it can contain unquoted field names, containing spaces, type names, and characters like '>'."); + assertThatThrownBy(() -> parseTypeString("STRUCT")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("STRUCT type is not supported, because it can contain unquoted field names, containing spaces, type names, and characters like '>'."); + assertThatThrownBy(() -> parseTypeString("STRUCT<# INT64>")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("STRUCT type is not supported, because it can contain unquoted field names, containing spaces, type names, and characters like '>'."); + + // leading and trailing whitespace is not permitted + assertThatThrownBy(() -> parseTypeString(" JSON")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of ' JSON' but ' ' is found."); + assertThatThrownBy(() -> parseTypeString("JSON ")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: unexpected character at the end of 'JSON '"); + + // invalid type parameters should not cause out of bounds errors + assertThatThrownBy(() -> parseTypeString("STRING(")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type parameters expected for type string STRING("); + assertThatThrownBy(() -> parseTypeString("STRING()")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type parameters expected for type string STRING()"); + assertThatThrownBy(() -> parseTypeString("STRING(1, 2)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Type string only takes zero or one parameter, but 2 is seen"); + assertThatThrownBy(() -> parseTypeString("NUMERIC(1, 2, 3)")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Type decimal only takes two parameters, but 3 is seen"); + assertThatThrownBy(() -> parseTypeString("ARRAY<")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: type expected at the end of 'ARRAY<'"); + assertThatThrownBy(() -> parseTypeString("ARRAY<>")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 6 of 'ARRAY<>' but '>' is found."); + + // handle special characters + assertThatThrownBy(() -> parseTypeString("()")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of '()' but '(' is found."); + assertThatThrownBy(() -> parseTypeString("''")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of '''' but ''' is found."); + assertThatThrownBy(() -> parseTypeString("łąka")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of 'łąka' but 'łąka' is found."); + assertThatThrownBy(() -> parseTypeString("_")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of '_' but '_' is found."); + assertThatThrownBy(() -> parseTypeString("#")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Error: 'type' expected at the position 0 of '#' but '#' is found."); + } + + private static void assertTypeInfo(String typeString) + { + assertThat(parseTypeString(typeString)) + .hasToString(typeString); + } + + private static void assertTypeInfo(String typeString, String toString) + { + assertThat(parseTypeString(typeString)) + .hasToString(toString); + } +}