diff --git a/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java b/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java index 2f76f9778429..5eeb0447bf74 100644 --- a/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java +++ b/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java @@ -13,6 +13,7 @@ */ package io.trino.connector; +import com.google.common.collect.ImmutableList; import io.trino.FullConnectorSession; import io.trino.Session; import io.trino.metadata.MaterializedViewDefinition; @@ -54,12 +55,12 @@ public Optional getRelationMetadata(ConnectorSession conne Optional materializedView = metadata.getMaterializedView(session, qualifiedName); if (materializedView.isPresent()) { - return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns()))); + return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns()), ImmutableList.of())); } Optional view = metadata.getView(session, qualifiedName); if (view.isPresent()) { - return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns()))); + return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns()), ImmutableList.of())); } Optional tableHandle = metadata.getTableHandle(session, qualifiedName); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index fa88130920d7..22269b9ec317 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -211,6 +211,7 @@ public class Analysis private final Multiset rowFilterScopes = HashMultiset.create(); private final Map, List> rowFilters = new LinkedHashMap<>(); + private final Map, List> checkConstraints = new LinkedHashMap<>(); private final Multiset columnMaskScopes = HashMultiset.create(); private final Map, Map>> columnMasks = new LinkedHashMap<>(); @@ -1070,11 +1071,22 @@ public void addRowFilter(Table table, Expression filter) .add(filter); } + public void addCheckConstraints(Table table, Expression filter) + { + checkConstraints.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>()) + .add(filter); + } + public List getRowFilters(Table node) { return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of()); } + public List getCheckConstraints(Table node) + { + return checkConstraints.getOrDefault(NodeRef.of(node), ImmutableList.of()); + } + public boolean hasColumnMask(QualifiedObjectName table, String column, String identity) { return columnMaskScopes.contains(new ColumnMaskScopeEntry(table, column, identity)); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 3b5ec1b24c5d..268f5ce9a741 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -535,6 +535,7 @@ protected Scope visitInsert(Insert insert, Optional scope) List columns = tableSchema.getColumns().stream() .filter(column -> !column.isHidden()) .collect(toImmutableList()); + List checkConstraints = tableSchema.getTableSchema().getCheckConstraints(); for (ColumnSchema column : columns) { if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) { @@ -544,7 +545,7 @@ protected Scope visitInsert(Insert insert, Optional scope) Map columnHandles = metadata.getColumnHandles(session, targetTableHandle.get()); List tableFields = analyzeTableOutputFields(insert.getTable(), targetTable, tableSchema, columnHandles); - analyzeFiltersAndMasks(insert.getTable(), targetTable, targetTableHandle, tableFields, session.getIdentity().getUser()); + analyzeFiltersAndMasks(insert.getTable(), targetTable, targetTableHandle, new RelationType(tableFields), session.getIdentity().getUser(), checkConstraints); List tableColumns = columns.stream() .map(ColumnSchema::getName) @@ -791,7 +792,7 @@ protected Scope visitDelete(Delete node, Optional scope) analysis.setUpdateType("DELETE"); analysis.setUpdateTarget(tableName, Optional.of(table), Optional.empty()); - analyzeFiltersAndMasks(table, tableName, Optional.of(handle), analysis.getScope(table).getRelationType(), session.getIdentity().getUser()); + analyzeFiltersAndMasks(table, tableName, Optional.of(handle), analysis.getScope(table).getRelationType(), session.getIdentity().getUser(), ImmutableList.of()); return createAndAssignScope(node, scope, Field.newUnqualified("rows", BIGINT)); } @@ -2145,10 +2146,10 @@ private void checkStorageTableNotRedirected(QualifiedObjectName source) private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional tableHandle, List fields, String authorization) { - analyzeFiltersAndMasks(table, name, tableHandle, new RelationType(fields), authorization); + analyzeFiltersAndMasks(table, name, tableHandle, new RelationType(fields), authorization, ImmutableList.of()); } - private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional tableHandle, RelationType relationType, String authorization) + private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional tableHandle, RelationType relationType, String authorization, List checkConstraints) { Scope accessControlScope = Scope.builder() .withRelationType(RelationId.anonymous(), relationType) @@ -2165,8 +2166,15 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optio } } - accessControl.getRowFilters(session.toSecurityContext(), name) - .forEach(filter -> analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter)); + for (ViewExpression filter : accessControl.getRowFilters(session.toSecurityContext(), name)) { + Expression expression = analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter); + analysis.addRowFilter(table, expression); + } + for (String checkConstraint : checkConstraints) { + ViewExpression filter = new ViewExpression(session.getIdentity().getUser(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), checkConstraint); + Expression expression = analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter); + analysis.addCheckConstraints(table, expression); + } analysis.registerTable(table, tableHandle, name, authorization, accessControlScope); } @@ -4491,7 +4499,7 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope, correlationSupport); } - private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter) + private Expression analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter) { if (analysis.hasRowFilter(name, currentIdentity)) { throw new TrinoException(INVALID_ROW_FILTER, extractLocation(table), format("Row filter for '%s' is recursive", name), null); @@ -4542,7 +4550,7 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje analysis.addCoercion(expression, BOOLEAN, coercion.isTypeOnlyCoercion(actualType, BOOLEAN)); } - analysis.addRowFilter(table, expression); + return expression; } private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, Field field, Scope scope, ViewExpression mask) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index 953ec8e4c3e7..57c94b027203 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -486,18 +486,8 @@ private RelationPlan getInsertPlan( plan = new RelationPlan(projectNode, scope, projectNode.getOutputSymbols(), Optional.empty()); - plan = planner.addRowFilters( - table, - plan, - failIfPredicateIsNotMet(metadata, session, PERMISSION_DENIED, AccessDeniedException.PREFIX + "Cannot insert row that does not match to a row filter"), - node -> { - Scope accessControlScope = analysis.getAccessControlScope(table); - // hidden fields are not accessible in insert - return Scope.builder() - .like(accessControlScope) - .withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields()) - .build(); - }); + plan = addRowFilters(analysis, planner, table, plan, failIfPredicateIsNotMet(metadata, session, PERMISSION_DENIED, AccessDeniedException.PREFIX + "Cannot insert row that does not match to a row filter"), analysis.getRowFilters(table)); + plan = addRowFilters(analysis, planner, table, plan, failIfCheckConstraintIsNotMet(metadata, session), analysis.getCheckConstraints(table)); List insertedTableColumnNames = insertedColumns.stream() .map(ColumnMetadata::getName) @@ -530,6 +520,23 @@ private RelationPlan getInsertPlan( statisticsMetadata); } + private RelationPlan addRowFilters(Analysis analysis, RelationPlanner planner, Table table, RelationPlan plan, Function predicateTransformation, List filters) + { + return planner.addRowFilters( + filters, + table, + plan, + predicateTransformation, + node -> { + Scope accessControlScope = analysis.getAccessControlScope(table); + // hidden fields are not accessible in insert + return Scope.builder() + .like(accessControlScope) + .withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields()) + .build(); + }); + } + private Expression createNullNotAllowedFailExpression(String columnName, Type type) { return new Cast(failFunction(metadata, session, CONSTRAINT_VIOLATION, format( @@ -543,6 +550,11 @@ private static Function failIfPredicateIsNotMet(Metadata return predicate -> new IfExpression(predicate, TRUE_LITERAL, new Cast(fail, toSqlType(BOOLEAN))); } + private static Function failIfCheckConstraintIsNotMet(Metadata metadata, Session session) + { + return predicate -> new IfExpression(predicate, TRUE_LITERAL, new Cast(failFunction(metadata, session, CONSTRAINT_VIOLATION, "Cannot insert row that does not match to a check constraint: " + predicate), toSqlType(BOOLEAN))); + } + public static FunctionCall failFunction(Metadata metadata, Session session, ErrorCodeSupplier errorCode, String errorMessage) { return FunctionCallBuilder.resolve(session, metadata) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 19d2d37b0d40..027bd5cc60df 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -267,7 +267,11 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function predicateTransformation, Function accessControlScope) { List filters = analysis.getRowFilters(node); + return addRowFilters(filters, node, plan, predicateTransformation, accessControlScope); + } + public RelationPlan addRowFilters(List filters, Table node, RelationPlan plan, Function predicateTransformation, Function accessControlScope) + { if (filters.isEmpty()) { return plan; } diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java index 2fc1722d21f8..c6e382947f34 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java @@ -132,6 +132,7 @@ public class MockConnector private final BiFunction> refreshMaterializedView; private final BiFunction getTableHandle; private final Function> getColumns; + private final Function> checkConstraints; private final MockConnectorFactory.ApplyProjection applyProjection; private final MockConnectorFactory.ApplyAggregation applyAggregation; private final MockConnectorFactory.ApplyJoin applyJoin; @@ -170,6 +171,7 @@ public class MockConnector BiFunction> refreshMaterializedView, BiFunction getTableHandle, Function> getColumns, + Function> checkConstraints, ApplyProjection applyProjection, ApplyAggregation applyAggregation, ApplyJoin applyJoin, @@ -206,6 +208,7 @@ public class MockConnector this.refreshMaterializedView = requireNonNull(refreshMaterializedView, "refreshMaterializedView is null"); this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null"); this.getColumns = requireNonNull(getColumns, "getColumns is null"); + this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null"); this.applyProjection = requireNonNull(applyProjection, "applyProjection is null"); this.applyAggregation = requireNonNull(applyAggregation, "applyAggregation is null"); this.applyJoin = requireNonNull(applyJoin, "applyJoin is null"); @@ -441,7 +444,12 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { MockConnectorTableHandle table = (MockConnectorTableHandle) tableHandle; - return new ConnectorTableMetadata(table.getTableName(), getColumns.apply(table.getTableName())); + return new ConnectorTableMetadata( + table.getTableName(), + getColumns.apply(table.getTableName()), + ImmutableMap.of(), + Optional.empty(), + checkConstraints.apply(table.getTableName())); } @Override diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java index c96a0387292a..db7c14cf828e 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java @@ -90,6 +90,7 @@ public class MockConnectorFactory private final BiFunction> refreshMaterializedView; private final BiFunction getTableHandle; private final Function> getColumns; + private final Function> checkConstraints; private final ApplyProjection applyProjection; private final ApplyAggregation applyAggregation; private final ApplyJoin applyJoin; @@ -130,6 +131,7 @@ private MockConnectorFactory( BiFunction> refreshMaterializedView, BiFunction getTableHandle, Function> getColumns, + Function> checkConstraints, ApplyProjection applyProjection, ApplyAggregation applyAggregation, ApplyJoin applyJoin, @@ -167,6 +169,7 @@ private MockConnectorFactory( this.refreshMaterializedView = requireNonNull(refreshMaterializedView, "refreshMaterializedView is null"); this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null"); this.getColumns = requireNonNull(getColumns, "getColumns is null"); + this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null"); this.applyProjection = requireNonNull(applyProjection, "applyProjection is null"); this.applyAggregation = requireNonNull(applyAggregation, "applyAggregation is null"); this.applyJoin = requireNonNull(applyJoin, "applyJoin is null"); @@ -214,6 +217,7 @@ public Connector create(String catalogName, Map config, Connecto refreshMaterializedView, getTableHandle, getColumns, + checkConstraints, applyProjection, applyAggregation, applyJoin, @@ -338,6 +342,7 @@ public static final class Builder private BiFunction> refreshMaterializedView = (session, viewName) -> CompletableFuture.completedFuture(null); private BiFunction getTableHandle = defaultGetTableHandle(); private Function> getColumns = defaultGetColumns(); + private Function> checkConstraints = (schemaTableName -> ImmutableList.of()); private ApplyProjection applyProjection = (session, handle, projections, assignments) -> Optional.empty(); private ApplyAggregation applyAggregation = (session, handle, aggregates, assignments, groupingSets) -> Optional.empty(); private ApplyJoin applyJoin = (session, joinType, left, right, joinConditions, leftAssignments, rightAssignments) -> Optional.empty(); @@ -451,6 +456,12 @@ public Builder withGetColumns(Function> ge return this; } + public Builder withCheckConstraints(Function> checkConstraints) + { + this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null"); + return this; + } + public Builder withApplyProjection(ApplyProjection applyProjection) { this.applyProjection = applyProjection; @@ -647,6 +658,7 @@ public MockConnectorFactory build() refreshMaterializedView, getTableHandle, getColumns, + checkConstraints, applyProjection, applyAggregation, applyJoin, diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestCheckConstraint.java b/core/trino-main/src/test/java/io/trino/sql/query/TestCheckConstraint.java new file mode 100644 index 000000000000..2f4a98973a3b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestCheckConstraint.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.query; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.connector.MockConnectorFactory; +import io.trino.plugin.tpch.TpchConnectorFactory; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.security.Identity; +import io.trino.testing.LocalQueryRunner; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static io.trino.connector.MockConnectorEntities.TPCH_NATION_DATA; +import static io.trino.connector.MockConnectorEntities.TPCH_NATION_SCHEMA; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestCheckConstraint +{ + private static final String LOCAL_CATALOG = "local"; + private static final String MOCK_CATALOG = "mock"; + private static final String USER = "user"; + + private static final Session SESSION = testSessionBuilder() + .setCatalog(LOCAL_CATALOG) + .setSchema(TINY_SCHEMA_NAME) + .setIdentity(Identity.forUser(USER).build()) + .build(); + + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + LocalQueryRunner runner = LocalQueryRunner.builder(SESSION).build(); + + runner.createCatalog(LOCAL_CATALOG, new TpchConnectorFactory(1), ImmutableMap.of()); + + MockConnectorFactory mock = MockConnectorFactory.builder() + .withGetColumns(schemaTableName -> { + if (schemaTableName.equals(new SchemaTableName("tiny", "nation"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_with_invalid_check_constraint"))) { + return TPCH_NATION_SCHEMA; + } + throw new UnsupportedOperationException(); + }) + .withCheckConstraints(schemaTableName -> { + if (schemaTableName.equals(new SchemaTableName("tiny", "nation"))) { + return ImmutableList.of("nationkey > 100"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_with_invalid_check_constraint"))) { + return ImmutableList.of("invalid_function(nationkey) > 100"); + } + throw new UnsupportedOperationException(); + }) + .withData(schemaTableName -> { + if (schemaTableName.equals(new SchemaTableName("tiny", "nation"))) { + return TPCH_NATION_DATA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_with_invalid_check_constraint"))) { + return TPCH_NATION_DATA; + } + throw new UnsupportedOperationException(); + }) + .build(); + + runner.createCatalog(MOCK_CATALOG, mock, ImmutableMap.of()); + + assertions = new QueryAssertions(runner); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testInsert() + { + assertions.query("INSERT INTO mock.tiny.nation VALUES (101, 'POLAND', 0, 'No comment')") + .assertThat() + .skippingTypesCheck() + .matches("SELECT BIGINT '1'"); + + // Outside allowed row filter + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation VALUES (26, 'POLAND', 0, 'No comment')")) + .hasMessage("Cannot insert row that does not match to a check constraint: (\"nationkey\" > CAST(100 AS bigint))"); + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation VALUES " + + "(26, 'POLAND', 0, 'No comment')," + + "(27, 'HOLLAND', 0, 'A comment')")) + .hasMessage("Cannot insert row that does not match to a check constraint: (\"nationkey\" > CAST(100 AS bigint))"); + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation VALUES " + + "(26, 'POLAND', 0, 'No comment')," + + "(27, 'HOLLAND', 0, 'A comment')")) + .hasMessage("Cannot insert row that does not match to a check constraint: (\"nationkey\" > CAST(100 AS bigint))"); + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation(nationkey) VALUES (null)")) + .hasMessage("Cannot insert row that does not match to a check constraint: (\"nationkey\" > CAST(100 AS bigint))"); + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation(regionkey) VALUES (0)")) + .hasMessage("Cannot insert row that does not match to a check constraint: (\"nationkey\" > CAST(100 AS bigint))"); + } + + @Test + public void testInsertUnsupportedConstraint() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_with_invalid_check_constraint VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Function 'invalid_function' not registered"); + } +} diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 15c74c6fae1d..f2d4f8c375e0 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -181,6 +181,11 @@ + + java.method.numberOfParametersChanged + method void io.trino.spi.connector.ConnectorTableSchema::<init>(io.trino.spi.connector.SchemaTableName, java.util.List<io.trino.spi.connector.ColumnSchema>) + method void io.trino.spi.connector.ConnectorTableSchema::<init>(io.trino.spi.connector.SchemaTableName, java.util.List<io.trino.spi.connector.ColumnSchema>, java.util.List<java.lang.String>) + java.method.numberOfParametersChanged method void io.trino.spi.eventlistener.QueryStatistics::<init>(java.time.Duration, java.time.Duration, java.time.Duration, java.time.Duration, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, java.util.Optional<java.time.Duration>, long, long, long, long, long, long, long, long, long, long, long, long, long, long, long, double, double, java.util.List<io.trino.spi.eventlistener.StageGcStatistics>, int, boolean, java.util.List<io.trino.spi.eventlistener.StageCpuDistribution>, java.util.List<java.util.Optional<io.trino.spi.metrics.Distribution<?>>>, java.util.List<java.lang.String>, java.util.Optional<java.lang.String>) diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java index cb1963095098..d9116829bb3f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.Optional; +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toUnmodifiableList; @@ -29,6 +30,7 @@ public class ConnectorTableMetadata private final Optional comment; private final List columns; private final Map properties; + private final List checkConstraints; public ConnectorTableMetadata(SchemaTableName table, List columns) { @@ -37,19 +39,26 @@ public ConnectorTableMetadata(SchemaTableName table, List column public ConnectorTableMetadata(SchemaTableName table, List columns, Map properties) { - this(table, columns, properties, Optional.empty()); + this(table, columns, properties, Optional.empty(), emptyList()); } public ConnectorTableMetadata(SchemaTableName table, List columns, Map properties, Optional comment) + { + this(table, columns, properties, comment, emptyList()); + } + + public ConnectorTableMetadata(SchemaTableName table, List columns, Map properties, Optional comment, List checkConstraints) { requireNonNull(table, "table is null"); requireNonNull(columns, "columns is null"); requireNonNull(comment, "comment is null"); + requireNonNull(checkConstraints, "checkConstraints is null"); this.table = table; this.columns = List.copyOf(columns); this.properties = Collections.unmodifiableMap(new LinkedHashMap<>(properties)); this.comment = comment; + this.checkConstraints = List.copyOf(checkConstraints); } public SchemaTableName getTable() @@ -72,13 +81,22 @@ public Optional getComment() return comment; } + /** + * @return List of string representation of a Trino SQL scalar expression that can refer to table columns by name and produces a result coercible to boolean + */ + public List getCheckConstraints() + { + return checkConstraints; + } + public ConnectorTableSchema getTableSchema() { return new ConnectorTableSchema( table, columns.stream() .map(ColumnMetadata::getColumnSchema) - .collect(toUnmodifiableList())); + .collect(toUnmodifiableList()), + checkConstraints); } @Override @@ -89,6 +107,7 @@ public String toString() sb.append(", columns=").append(columns); sb.append(", properties=").append(properties); comment.ifPresent(value -> sb.append(", comment='").append(value).append("'")); + sb.append(", checkConstraints=").append(checkConstraints); sb.append('}'); return sb.toString(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java index afa9694ecb66..baff7d0f568f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java @@ -21,14 +21,17 @@ public class ConnectorTableSchema { private final SchemaTableName table; private final List columns; + private final List checkConstraints; - public ConnectorTableSchema(SchemaTableName table, List columns) + public ConnectorTableSchema(SchemaTableName table, List columns, List checkConstraints) { requireNonNull(table, "table is null"); requireNonNull(columns, "columns is null"); + requireNonNull(checkConstraints, "checkConstraints is null"); this.table = table; this.columns = List.copyOf(columns); + this.checkConstraints = List.copyOf(checkConstraints); } public SchemaTableName getTable() @@ -41,12 +44,18 @@ public List getColumns() return columns; } + public List getCheckConstraints() + { + return checkConstraints; + } + @Override public String toString() { return new StringBuilder("ConnectorTableSchema{") .append("table=").append(table) .append(", columns=").append(columns) + .append(", checkConstraints=").append(checkConstraints) .append('}') .toString(); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 4a214795de95..7216bffb62e9 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -626,7 +626,8 @@ public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTa getSchemaTableName(handle), jdbcClient.getColumns(session, handle).stream() .map(JdbcColumnHandle::getColumnSchema) - .collect(toImmutableList())); + .collect(toImmutableList()), + ImmutableList.of()); } @Override diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 07733d89681c..e7da19a57ecd 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -446,6 +446,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect List columns = getColumns(tableHandle.getMetadataEntry()).stream() .map(column -> getColumnMetadata(column, columnComments.get(column.getName()), columnsNullability.getOrDefault(column.getName(), true))) .collect(toImmutableList()); + List checkConstraints = getCheckConstraints(tableHandle.getMetadataEntry()).values().stream().collect(toImmutableList()); ImmutableMap.Builder properties = ImmutableMap.builder() .put(LOCATION_PROPERTY, location) @@ -458,7 +459,8 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect tableHandle.getSchemaTableName(), columns, properties.buildOrThrow(), - Optional.ofNullable(tableHandle.getMetadataEntry().getDescription())); + Optional.ofNullable(tableHandle.getMetadataEntry().getDescription()), + checkConstraints); } @Override @@ -1263,9 +1265,6 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto if (!columnInvariants.isEmpty()) { throw new TrinoException(NOT_SUPPORTED, "Inserts are not supported for tables with delta invariants"); } - if (!getCheckConstraints(table.getMetadataEntry()).isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, "Writing to tables with CHECK constraints is not supported"); - } checkUnsupportedGeneratedColumns(table.getMetadataEntry()); checkSupportedWriterVersion(session, table.getSchemaTableName()); diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java index cfc3beb2b1f0..69b5aa065b5e 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java @@ -118,7 +118,8 @@ public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTa getSchemaTableName(handle), getColumnMetadata(session, handle).stream() .map(ColumnMetadata::getColumnSchema) - .collect(toImmutableList())); + .collect(toImmutableList()), + ImmutableList.of()); } @Override diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java index 11e277c5990e..cbf230c90489 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java @@ -29,6 +29,7 @@ import static io.trino.tempto.assertions.QueryAssert.assertThat; import static io.trino.tests.product.TestGroups.DELTA_LAKE_DATABRICKS; import static io.trino.tests.product.TestGroups.DELTA_LAKE_EXCLUDE_73; +import static io.trino.tests.product.TestGroups.DELTA_LAKE_EXCLUDE_91; import static io.trino.tests.product.TestGroups.DELTA_LAKE_OSS; import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS; import static io.trino.tests.product.deltalake.util.DeltaLakeTestUtils.DATABRICKS_104_RUNTIME_VERSION; @@ -318,7 +319,7 @@ public void testDeleteCompatibility() @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_OSS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) public void testCheckConstraintsCompatibility() { - // CHECK constraint is not supported by Trino + // CHECK constraint is supported only in INSERT statement by Trino String tableName = "test_check_constraint_not_supported_" + randomTableSuffix(); onDelta().executeQuery("CREATE TABLE default." + tableName + @@ -332,13 +333,47 @@ public void testCheckConstraintsCompatibility() assertThat(onTrino().executeQuery("SELECT id, a_number FROM " + tableName)) .containsOnly(row(1, 1)); + onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (2, 2)"); + assertThat(onTrino().executeQuery("SELECT id, a_number FROM " + tableName)) + .containsOnly(row(1, 1), row(2, 2)); - assertQueryFailure(() -> onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (2, 2)")) - .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); + assertQueryFailure(() -> onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (100, 100)")) + .hasMessageContaining("Cannot insert row that does not match to a check constraint"); assertQueryFailure(() -> onTrino().executeQuery("DELETE FROM delta.default." + tableName + " WHERE a_number = 1")) .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); assertQueryFailure(() -> onTrino().executeQuery("UPDATE delta.default." + tableName + " SET a_number = 10 WHERE id = 1")) .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); + assertQueryFailure(() -> onTrino().executeQuery("MERGE INTO delta.default." + tableName + " t USING delta.default." + tableName + " s " + + "ON (t.id = s.id) WHEN MATCHED THEN UPDATE SET a_number = 42")) + .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); + + assertThat(onTrino().executeQuery("SELECT id, a_number FROM " + tableName)) + .containsOnly(row(1, 1), row(2, 2)); + } + finally { + onDelta().executeQuery("DROP TABLE default." + tableName); + } + } + + @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_OSS, DELTA_LAKE_EXCLUDE_73, DELTA_LAKE_EXCLUDE_91, PROFILE_SPECIFIC_TESTS}) + public void testUnsupportedCheckConstraintExpression() + { + String tableName = "test_unsupported_check_constraint_expression_" + randomTableSuffix(); + + onDelta().executeQuery("CREATE TABLE default." + tableName + + "(id INT, a_number INT) " + + "USING DELTA " + + "LOCATION 's3://" + bucketName + "/databricks-compatibility-test-" + tableName + "'"); + // Use unsupported try_add function in CHECK constraint + onDelta().executeQuery("ALTER TABLE default." + tableName + " ADD CONSTRAINT id_constraint CHECK (try_add(id, 1) < 100)"); + + try { + onDelta().executeQuery("INSERT INTO default." + tableName + " (id, a_number) VALUES (1, 1)"); + assertThat(onTrino().executeQuery("SELECT id, a_number FROM " + tableName)) + .containsOnly(row(1, 1)); + + assertQueryFailure(() -> onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (2, 2)")) + .hasMessageContaining("Function 'try_add' not registered"); assertThat(onTrino().executeQuery("SELECT id, a_number FROM " + tableName)) .containsOnly(row(1, 1)); @@ -439,31 +474,6 @@ public Object[][] compressionCodecs() }; } - @Test(groups = {DELTA_LAKE_OSS, DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) - public void testWritesToTableWithCheckConstraintFails() - { - String tableName = "test_writes_into_table_with_check_constraint_" + randomTableSuffix(); - try { - onDelta().executeQuery("CREATE TABLE default." + tableName + " (a INT, b INT) " + - "USING DELTA " + - "LOCATION 's3://" + bucketName + "/databricks-compatibility-test-" + tableName + "'"); - onDelta().executeQuery("ALTER TABLE default." + tableName + " ADD CONSTRAINT aIsPositive CHECK (a > 0)"); - - assertQueryFailure(() -> onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (1, 2)")) - .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); - assertQueryFailure(() -> onTrino().executeQuery("UPDATE delta.default." + tableName + " SET a = 3 WHERE b = 3")) - .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); - assertQueryFailure(() -> onTrino().executeQuery("DELETE FROM delta.default." + tableName + " WHERE a = 3")) - .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); - assertQueryFailure(() -> onTrino().executeQuery("MERGE INTO delta.default." + tableName + " t USING delta.default." + tableName + " s " + - "ON (t.a = s.a) WHEN MATCHED THEN UPDATE SET b = 42")) - .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); - } - finally { - onDelta().executeQuery("DROP TABLE IF EXISTS default." + tableName); - } - } - @Test(groups = {DELTA_LAKE_OSS, DELTA_LAKE_DATABRICKS, DELTA_LAKE_EXCLUDE_73, PROFILE_SPECIFIC_TESTS}) public void testMetadataOperationsRetainCheckConstraints() { @@ -477,9 +487,6 @@ public void testMetadataOperationsRetainCheckConstraints() onTrino().executeQuery("ALTER TABLE delta.default." + tableName + " ADD COLUMN c INT"); onTrino().executeQuery("COMMENT ON COLUMN delta.default." + tableName + ".c IS 'example column comment'"); onTrino().executeQuery("COMMENT ON TABLE delta.default." + tableName + " IS 'example table comment'"); - - assertQueryFailure(() -> onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (1, 2, 3)")) - .hasMessageContaining("Writing to tables with CHECK constraints is not supported"); } finally { onDelta().executeQuery("DROP TABLE IF EXISTS default." + tableName);