diff --git a/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java b/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java index 2f76f9778429..5eeb0447bf74 100644 --- a/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java +++ b/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java @@ -13,6 +13,7 @@ */ package io.trino.connector; +import com.google.common.collect.ImmutableList; import io.trino.FullConnectorSession; import io.trino.Session; import io.trino.metadata.MaterializedViewDefinition; @@ -54,12 +55,12 @@ public Optional getRelationMetadata(ConnectorSession conne Optional materializedView = metadata.getMaterializedView(session, qualifiedName); if (materializedView.isPresent()) { - return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns()))); + return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns()), ImmutableList.of())); } Optional view = metadata.getView(session, qualifiedName); if (view.isPresent()) { - return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns()))); + return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns()), ImmutableList.of())); } Optional tableHandle = metadata.getTableHandle(session, qualifiedName); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 2a95b022d927..0e43dbaaf58d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -89,7 +89,6 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.Deque; import 
java.util.HashSet; import java.util.LinkedHashMap; @@ -112,6 +111,7 @@ import static java.lang.Boolean.FALSE; import static java.lang.String.format; import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; @@ -212,6 +212,7 @@ public class Analysis private final Multiset rowFilterScopes = HashMultiset.create(); private final Map, List> rowFilters = new LinkedHashMap<>(); + private final Map, List> checkConstraints = new LinkedHashMap<>(); private final Multiset columnMaskScopes = HashMultiset.create(); private final Map, Map>> columnMasks = new LinkedHashMap<>(); @@ -1071,11 +1072,22 @@ public void addRowFilter(Table table, Expression filter) .add(filter); } + public void addCheckConstraints(Table table, Expression constraint) + { + checkConstraints.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>()) + .add(constraint); + } + public List getRowFilters(Table node) { return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of()); } + public List getCheckConstraints(Table node) + { + return unmodifiableList(checkConstraints.getOrDefault(NodeRef.of(node), ImmutableList.of())); + } + public boolean hasColumnMask(QualifiedObjectName table, String column, String identity) { return columnMaskScopes.contains(new ColumnMaskScopeEntry(table, column, identity)); @@ -1572,22 +1584,22 @@ public void addQuantifiedComparisons(List expres public List getInPredicatesSubqueries() { - return Collections.unmodifiableList(inPredicatesSubqueries); + return unmodifiableList(inPredicatesSubqueries); } public List getSubqueries() { - return Collections.unmodifiableList(subqueries); + return unmodifiableList(subqueries); } public List getExistsSubqueries() { - return Collections.unmodifiableList(existsSubqueries); + return unmodifiableList(existsSubqueries); } public List 
getQuantifiedComparisonSubqueries() { - return Collections.unmodifiableList(quantifiedComparisonSubqueries); + return unmodifiableList(quantifiedComparisonSubqueries); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 9d9781016ea5..f8cf289eed0e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -115,7 +115,6 @@ import io.trino.sql.analyzer.Scope.AsteriskedIdentifierChainBasis; import io.trino.sql.parser.ParsingException; import io.trino.sql.parser.SqlParser; -import io.trino.sql.planner.DeterminismEvaluator; import io.trino.sql.planner.ExpressionInterpreter; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.ScopeAware; @@ -297,6 +296,7 @@ import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_WINDOW; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static io.trino.spi.StandardErrorCode.INVALID_CHECK_CONSTRAINT; import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_REFERENCE; import static io.trino.spi.StandardErrorCode.INVALID_COPARTITIONING; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -365,6 +365,8 @@ import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; +import static io.trino.sql.planner.DeterminismEvaluator.containsCurrentTimeFunctions; +import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static 
io.trino.sql.tree.DereferenceExpression.getQualifiedName; @@ -542,6 +544,7 @@ protected Scope visitInsert(Insert insert, Optional scope) List columns = tableSchema.getColumns().stream() .filter(column -> !column.isHidden()) .collect(toImmutableList()); + List checkConstraints = tableSchema.getTableSchema().getCheckConstraints(); for (ColumnSchema column : columns) { if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) { @@ -551,7 +554,12 @@ protected Scope visitInsert(Insert insert, Optional scope) Map columnHandles = metadata.getColumnHandles(session, targetTableHandle.get()); List tableFields = analyzeTableOutputFields(insert.getTable(), targetTable, tableSchema, columnHandles); - analyzeFiltersAndMasks(insert.getTable(), targetTable, targetTableHandle, tableFields, session.getIdentity().getUser()); + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), new RelationType(tableFields)) + .build(); + analyzeFiltersAndMasks(insert.getTable(), targetTable, new RelationType(tableFields), accessControlScope); + analyzeCheckConstraints(insert.getTable(), targetTable, accessControlScope, checkConstraints); + analysis.registerTable(insert.getTable(), targetTableHandle, targetTable, session.getIdentity().getUser(), accessControlScope); List tableColumns = columns.stream() .map(ColumnSchema::getName) @@ -801,7 +809,12 @@ protected Scope visitDelete(Delete node, Optional scope) analysis.setUpdateType("DELETE"); analysis.setUpdateTarget(tableName, Optional.of(table), Optional.empty()); - analyzeFiltersAndMasks(table, tableName, Optional.of(handle), analysis.getScope(table).getRelationType(), session.getIdentity().getUser()); + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), analysis.getScope(table).getRelationType()) + .build(); + analyzeFiltersAndMasks(table, tableName, analysis.getScope(table).getRelationType(), accessControlScope); + 
analyzeCheckConstraints(table, tableName, accessControlScope, tableSchema.getTableSchema().getCheckConstraints()); + analysis.registerTable(table, Optional.of(handle), tableName, session.getIdentity().getUser(), accessControlScope); createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of()); @@ -2188,7 +2201,12 @@ protected Scope visitTable(Table table, Optional scope) List outputFields = fields.build(); - analyzeFiltersAndMasks(table, targetTableName, tableHandle, outputFields, session.getIdentity().getUser()); + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), new RelationType(outputFields)) + .build(); + analyzeFiltersAndMasks(table, targetTableName, new RelationType(outputFields), accessControlScope); + analyzeCheckConstraints(table, targetTableName, accessControlScope, tableSchema.getTableSchema().getCheckConstraints()); + analysis.registerTable(table, tableHandle, targetTableName, session.getIdentity().getUser(), accessControlScope); Scope tableScope = createAndAssignScope(table, scope, outputFields); @@ -2208,17 +2226,8 @@ private void checkStorageTableNotRedirected(QualifiedObjectName source) }); } - private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional tableHandle, List fields, String authorization) - { - analyzeFiltersAndMasks(table, name, tableHandle, new RelationType(fields), authorization); - } - - private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional tableHandle, RelationType relationType, String authorization) + private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, RelationType relationType, Scope accessControlScope) { - Scope accessControlScope = Scope.builder() - .withRelationType(RelationId.anonymous(), relationType) - .build(); - for (int index = 0; index < relationType.getAllFieldCount(); index++) { Field field = relationType.getFieldByIndex(index); if (field.getName().isPresent()) { @@ -2232,8 
+2241,14 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optio accessControl.getRowFilters(session.toSecurityContext(), name) .forEach(filter -> analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter)); + } - analysis.registerTable(table, tableHandle, name, authorization, accessControlScope); + private void analyzeCheckConstraints(Table table, QualifiedObjectName name, Scope accessControlScope, List constraints) + { + for (String constraint : constraints) { + ViewExpression expression = new ViewExpression(session.getIdentity().getUser(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), constraint); + analyzeCheckConstraint(table, name, accessControlScope, expression); + } } private boolean checkCanSelectFromColumn(QualifiedObjectName name, String column) @@ -2375,13 +2390,21 @@ private Scope createScopeForView( if (storageTable.isPresent()) { List storageTableFields = analyzeStorageTable(table, viewFields, storageTable.get()); - analyzeFiltersAndMasks(table, name, storageTable, viewFields, session.getIdentity().getUser()); + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), new RelationType(viewFields)) + .build(); + analyzeFiltersAndMasks(table, name, new RelationType(viewFields), accessControlScope); + analysis.registerTable(table, storageTable, name, session.getIdentity().getUser(), accessControlScope); analysis.addRelationCoercion(table, viewFields.stream().map(Field::getType).toArray(Type[]::new)); // use storage table output fields as they contain ColumnHandles return createAndAssignScope(table, scope, storageTableFields); } - analyzeFiltersAndMasks(table, name, storageTable, viewFields, session.getIdentity().getUser()); + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), new RelationType(viewFields)) + .build(); + analyzeFiltersAndMasks(table, name, new RelationType(viewFields), accessControlScope); 
+ analysis.registerTable(table, storageTable, name, session.getIdentity().getUser(), accessControlScope); viewFields.forEach(field -> analysis.addSourceColumns(field, ImmutableSet.of(new SourceColumn(name, field.getName().orElseThrow())))); analysis.registerNamedQuery(table, query); return createAndAssignScope(table, scope, viewFields); @@ -3174,6 +3197,10 @@ protected Scope visitUpdate(Update update, Optional scope) if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) { throw semanticException(NOT_SUPPORTED, update, "Updating a table with a row filter is not supported"); } + if (!tableSchema.getTableSchema().getCheckConstraints().isEmpty()) { + // TODO https://github.com/trinodb/trino/issues/15411 Add support for CHECK constraint to UPDATE statement + throw semanticException(NOT_SUPPORTED, update, "Updating a table with a check constraint is not supported"); + } // TODO: how to deal with connectors that need to see the pre-image of rows to perform the update without // flowing that data through the masking logic @@ -3301,6 +3328,10 @@ protected Scope visitMerge(Merge merge, Optional scope) if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) { throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with row filters"); } + if (!tableSchema.getTableSchema().getCheckConstraints().isEmpty()) { + // TODO https://github.com/trinodb/trino/issues/15411 Add support for CHECK constraint to MERGE statement + throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with check constraints"); + } Scope targetTableScope = analyzer.analyzeForUpdate(relation, scope, UpdateKind.MERGE); Scope sourceTableScope = process(merge.getSource(), scope); @@ -4646,6 +4677,62 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje analysis.addRowFilter(table, expression); } + private void analyzeCheckConstraint(Table table, QualifiedObjectName name, Scope scope, 
ViewExpression constraint) + { + Expression expression; + try { + expression = sqlParser.createExpression(constraint.getExpression(), createParsingOptions(session)); + } + catch (ParsingException e) { + throw new TrinoException(INVALID_CHECK_CONSTRAINT, extractLocation(table), format("Invalid check constraint for '%s': %s", name, e.getErrorMessage()), e); + } + + verifyNoAggregateWindowOrGroupingFunctions(session, metadata, expression, format("Check constraint for '%s'", name)); + + ExpressionAnalysis expressionAnalysis; + try { + Identity filterIdentity = Identity.forUser(constraint.getIdentity()) + .withGroups(groupProvider.getGroups(constraint.getIdentity())) + .build(); + expressionAnalysis = ExpressionAnalyzer.analyzeExpression( + createViewSession(constraint.getCatalog(), constraint.getSchema(), filterIdentity, session.getPath()), + plannerContext, + statementAnalyzerFactory, + accessControl, + scope, + analysis, + expression, + warningCollector, + correlationSupport); + } + catch (TrinoException e) { + throw new TrinoException(e::getErrorCode, extractLocation(table), format("Invalid check constraint for '%s': %s", name, e.getRawMessage()), e); + } + + // Ensure that the expression doesn't contain non-deterministic functions. This should be "retrospectively deterministic" per SQL standard. 
+ if (!isDeterministic(expression, this::getResolvedFunction)) { + throw semanticException(INVALID_CHECK_CONSTRAINT, expression, "Check constraint expression should be deterministic"); + } + if (containsCurrentTimeFunctions(expression)) { + throw semanticException(INVALID_CHECK_CONSTRAINT, expression, "Check constraint expression should not contain temporal expression"); + } + + analysis.recordSubqueries(expression, expressionAnalysis); + + Type actualType = expressionAnalysis.getType(expression); + if (!actualType.equals(BOOLEAN)) { + TypeCoercion coercion = new TypeCoercion(plannerContext.getTypeManager()::getType); + + if (!coercion.canCoerce(actualType, BOOLEAN)) { + throw new TrinoException(TYPE_MISMATCH, extractLocation(table), format("Expected check constraint for '%s' to be of type BOOLEAN, but was %s", name, actualType), null); + } + + analysis.addCoercion(expression, BOOLEAN, coercion.isTypeOnlyCoercion(actualType, BOOLEAN)); + } + + analysis.addCheckConstraints(table, expression); + } + private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, Field field, Scope scope, ViewExpression mask) { String column = field.getName().orElseThrow(); @@ -5031,7 +5118,7 @@ private void verifySelectDistinct(QuerySpecification node, List orde } for (Expression expression : orderByExpressions) { - if (!DeterminismEvaluator.isDeterministic(expression, this::getResolvedFunction)) { + if (!isDeterministic(expression, this::getResolvedFunction)) { throw semanticException(EXPRESSION_NOT_IN_DISTINCT, expression, "Non deterministic ORDER BY expression is not supported with SELECT DISTINCT"); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java index 25a10a35cb82..ef00948c3acc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java +++ 
b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java @@ -15,6 +15,7 @@ import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.sql.tree.CurrentTime; import io.trino.sql.tree.DefaultExpressionTraversalVisitor; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; @@ -65,4 +66,24 @@ protected Void visitFunctionCall(FunctionCall node, AtomicBoolean deterministic) return super.visitFunctionCall(node, deterministic); } } + + public static boolean containsCurrentTimeFunctions(Expression expression) + { + requireNonNull(expression, "expression is null"); + + AtomicBoolean currentTime = new AtomicBoolean(false); + new CurrentTimeVisitor().process(expression, currentTime); + return currentTime.get(); + } + + private static class CurrentTimeVisitor + extends DefaultExpressionTraversalVisitor + { + @Override + protected Void visitCurrentTime(CurrentTime node, AtomicBoolean currentTime) + { + currentTime.set(true); + return super.visitCurrentTime(node, currentTime); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index 44ef9c03aa70..c876c8cd9a22 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -498,6 +498,18 @@ private RelationPlan getInsertPlan( .withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields()) .build(); }); + plan = planner.addCheckConstraints( + analysis.getCheckConstraints(table), + table, + plan, + node -> { + Scope accessControlScope = analysis.getAccessControlScope(table); + // hidden fields are not accessible in insert + return Scope.builder() + .like(accessControlScope) + .withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields()) + 
.build(); + }); List insertedTableColumnNames = insertedColumns.stream() .map(ColumnMetadata::getName) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 2b0e420e27d8..3a62b9a801d5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -65,6 +65,7 @@ import io.trino.sql.tree.Except; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.Intersect; import io.trino.sql.tree.Join; import io.trino.sql.tree.JoinCriteria; @@ -113,10 +114,13 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.QueryPlanner.coerce; import static io.trino.sql.planner.QueryPlanner.coerceIfNecessary; @@ -291,6 +295,33 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function constraints, Table node, RelationPlan plan, Function accessControlScope) + { + if (constraints.isEmpty()) { + return plan; + } + + PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext) + .withScope(accessControlScope.apply(node), 
plan.getFieldMappings()); // The fields in the access control scope have the same layout as those for the table scope + + for (Expression constraint : constraints) { + planBuilder = subqueryPlanner.handleSubqueries(planBuilder, constraint, analysis.getSubqueries(constraint)); + + Expression predicate = new IfExpression( + // When predicate evaluates to UNKNOWN (e.g. NULL > 100), it should not violate the check constraint. + new CoalesceExpression(coerceIfNecessary(analysis, constraint, planBuilder.rewrite(constraint)), TRUE_LITERAL), + TRUE_LITERAL, + new Cast(failFunction(plannerContext.getMetadata(), session, CONSTRAINT_VIOLATION, "Check constraint violation: " + constraint), toSqlType(BOOLEAN))); + + planBuilder = planBuilder.withNewRoot(new FilterNode( + idAllocator.getNextId(), + planBuilder.getRoot(), + predicate)); + } + + return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext); + } + private RelationPlan addColumnMasks(Table table, RelationPlan plan) { Map> columnMasks = analysis.getColumnMasks(table); diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java index edc17d785ef1..7abea1fcca83 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java @@ -136,6 +136,7 @@ public class MockConnector private final BiFunction getTableHandle; private final Function> getColumns; private final Function getTableStatistics; + private final Function> checkConstraints; private final MockConnectorFactory.ApplyProjection applyProjection; private final MockConnectorFactory.ApplyAggregation applyAggregation; private final MockConnectorFactory.ApplyJoin applyJoin; @@ -177,6 +178,7 @@ public class MockConnector BiFunction getTableHandle, Function> getColumns, Function getTableStatistics, + Function> checkConstraints, ApplyProjection 
applyProjection, ApplyAggregation applyAggregation, ApplyJoin applyJoin, @@ -216,6 +218,7 @@ public class MockConnector this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null"); this.getColumns = requireNonNull(getColumns, "getColumns is null"); this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null"); this.applyProjection = requireNonNull(applyProjection, "applyProjection is null"); this.applyAggregation = requireNonNull(applyAggregation, "applyAggregation is null"); this.applyJoin = requireNonNull(applyJoin, "applyJoin is null"); @@ -465,7 +468,12 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { MockConnectorTableHandle table = (MockConnectorTableHandle) tableHandle; - return new ConnectorTableMetadata(table.getTableName(), getColumns.apply(table.getTableName())); + return new ConnectorTableMetadata( + table.getTableName(), + getColumns.apply(table.getTableName()), + ImmutableMap.of(), + Optional.empty(), + checkConstraints.apply(table.getTableName())); } @Override diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java index 8f6eaebdab52..c722c0f28fb3 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java @@ -93,6 +93,7 @@ public class MockConnectorFactory private final BiFunction getTableHandle; private final Function> getColumns; private final Function getTableStatistics; + private final Function> checkConstraints; private final ApplyProjection applyProjection; private final ApplyAggregation applyAggregation; private final ApplyJoin applyJoin; @@ -136,6 
+137,7 @@ private MockConnectorFactory( BiFunction getTableHandle, Function> getColumns, Function getTableStatistics, + Function> checkConstraints, ApplyProjection applyProjection, ApplyAggregation applyAggregation, ApplyJoin applyJoin, @@ -176,6 +178,7 @@ private MockConnectorFactory( this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null"); this.getColumns = requireNonNull(getColumns, "getColumns is null"); this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null"); this.applyProjection = requireNonNull(applyProjection, "applyProjection is null"); this.applyAggregation = requireNonNull(applyAggregation, "applyAggregation is null"); this.applyJoin = requireNonNull(applyJoin, "applyJoin is null"); @@ -226,6 +229,7 @@ public Connector create(String catalogName, Map config, Connecto getTableHandle, getColumns, getTableStatistics, + checkConstraints, applyProjection, applyAggregation, applyJoin, @@ -353,6 +357,7 @@ public static final class Builder private BiFunction getTableHandle = defaultGetTableHandle(); private Function> getColumns = defaultGetColumns(); private Function getTableStatistics = schemaTableName -> empty(); + private Function> checkConstraints = (schemaTableName -> ImmutableList.of()); private ApplyProjection applyProjection = (session, handle, projections, assignments) -> Optional.empty(); private ApplyAggregation applyAggregation = (session, handle, aggregates, assignments, groupingSets) -> Optional.empty(); private ApplyJoin applyJoin = (session, joinType, left, right, joinConditions, leftAssignments, rightAssignments) -> Optional.empty(); @@ -474,6 +479,12 @@ public Builder withGetTableStatistics(Function return this; } + public Builder withCheckConstraints(Function> checkConstraints) + { + this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null"); + return this; + } + public 
Builder withApplyProjection(ApplyProjection applyProjection) { this.applyProjection = applyProjection; @@ -683,6 +694,7 @@ public MockConnectorFactory build() getTableHandle, getColumns, getTableStatistics, + checkConstraints, applyProjection, applyAggregation, applyJoin, diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestCheckConstraint.java b/core/trino-main/src/test/java/io/trino/sql/query/TestCheckConstraint.java new file mode 100644 index 000000000000..b9fe3b011c6a --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestCheckConstraint.java @@ -0,0 +1,439 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.query; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.connector.MockConnectorFactory; +import io.trino.plugin.tpch.TpchConnectorFactory; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.security.Identity; +import io.trino.testing.LocalQueryRunner; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static io.trino.connector.MockConnectorEntities.TPCH_NATION_DATA; +import static io.trino.connector.MockConnectorEntities.TPCH_NATION_SCHEMA; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestCheckConstraint +{ + private static final String LOCAL_CATALOG = "local"; + private static final String MOCK_CATALOG = "mock"; + private static final String USER = "user"; + + private static final Session SESSION = testSessionBuilder() + .setCatalog(LOCAL_CATALOG) + .setSchema(TINY_SCHEMA_NAME) + .setIdentity(Identity.forUser(USER).build()) + .build(); + + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + LocalQueryRunner runner = LocalQueryRunner.builder(SESSION).build(); + + runner.createCatalog(LOCAL_CATALOG, new TpchConnectorFactory(1), ImmutableMap.of()); + + MockConnectorFactory mock = MockConnectorFactory.builder() + .withGetColumns(schemaTableName -> { + if (schemaTableName.equals(new SchemaTableName("tiny", "nation"))) { + 
return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_multiple_column_constraint"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_invalid_function"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_not_boolean_expression"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_subquery"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_current_date"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_current_time"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_current_timestamp"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_localtime"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_localtimestamp"))) { + return TPCH_NATION_SCHEMA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_not_deterministic"))) { + return TPCH_NATION_SCHEMA; + } + throw new UnsupportedOperationException(); + }) + .withCheckConstraints(schemaTableName -> { + if (schemaTableName.equals(new SchemaTableName("tiny", "nation"))) { + return ImmutableList.of("regionkey < 10"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_multiple_column_constraint"))) { + return ImmutableList.of("nationkey > 100 AND regionkey > 50"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_invalid_function"))) { + return ImmutableList.of("invalid_function(nationkey) > 100"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_not_boolean_expression"))) { + return ImmutableList.of("1 + 1"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_subquery"))) { + return 
ImmutableList.of("nationkey > (SELECT count(*) FROM nation)"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_current_date"))) { + return ImmutableList.of("CURRENT_DATE > DATE '2022-12-31'"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_current_time"))) { + return ImmutableList.of("CURRENT_TIME > TIME '12:34:56.123+00:00'"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_current_timestamp"))) { + return ImmutableList.of("CURRENT_TIMESTAMP > TIMESTAMP '2022-12-31 23:59:59'"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_localtime"))) { + return ImmutableList.of("LOCALTIME > TIME '12:34:56.123'"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_localtimestamp"))) { + return ImmutableList.of("LOCALTIMESTAMP > TIMESTAMP '2022-12-31 23:59:59'"); + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_not_deterministic"))) { + return ImmutableList.of("nationkey > random()"); + } + throw new UnsupportedOperationException(); + }) + .withData(schemaTableName -> { + if (schemaTableName.equals(new SchemaTableName("tiny", "nation"))) { + return TPCH_NATION_DATA; + } + if (schemaTableName.equals(new SchemaTableName("tiny", "nation_multiple_column_constraint"))) { + return TPCH_NATION_DATA; + } + throw new UnsupportedOperationException(); + }) + .build(); + + runner.createCatalog(MOCK_CATALOG, mock, ImmutableMap.of()); + + assertions = new QueryAssertions(runner); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + + /** + * @see #testMergeInsert() + */ + @Test + public void testInsert() + { + assertThat(assertions.query("INSERT INTO mock.tiny.nation VALUES (101, 'POLAND', 0, 'No comment')")) + .matches("SELECT BIGINT '1'"); + + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation VALUES (26, 'POLAND', 11, 'No comment')")) + .hasMessage("Check constraint violation: (regionkey < 10)"); + 
assertThatThrownBy(() -> assertions.query(""" + INSERT INTO mock.tiny.nation VALUES + (26, 'POLAND', 11, 'No comment'), + (27, 'HOLLAND', 11, 'A comment') + """)) + .hasMessage("Check constraint violation: (regionkey < 10)"); + assertThatThrownBy(() -> assertions.query(""" + INSERT INTO mock.tiny.nation VALUES + (26, 'POLAND', 11, 'No comment'), + (27, 'HOLLAND', 11, 'A comment') + """)) + .hasMessage("Check constraint violation: (regionkey < 10)"); + } + + /** + * Like {@link #testInsert} but using the MERGE statement. + */ + @Test + public void testMergeInsert() + { + // Within allowed check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 42) t(dummy) ON false + WHEN NOT MATCHED THEN INSERT VALUES (101, 'POLAND', 0, 'No comment') + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + + // Outside allowed check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 42) t(dummy) ON false + WHEN NOT MATCHED THEN INSERT VALUES (26, 'POLAND', 0, 'No comment') + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES (26, 'POLAND', 0, 'No comment'), (27, 'HOLLAND', 0, 'A comment')) t(a,b,c,d) ON nationkey = a + WHEN NOT MATCHED THEN INSERT VALUES (a,b,c,d) + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 42) t(dummy) ON false + WHEN NOT MATCHED THEN INSERT (nationkey) VALUES (NULL) + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 42) t(dummy) ON false + WHEN NOT MATCHED THEN INSERT (nationkey) VALUES (0) + """)) + .hasMessage("line 1:1: Cannot merge into a table with 
check constraints"); + } + + @Test + public void testInsertAllowUnknown() + { + // Predicate evaluates to UNKNOWN (e.g. NULL > 100) should not violate check constraint + assertThat(assertions.query("INSERT INTO mock.tiny.nation(nationkey) VALUES (null)")) + .matches("SELECT BIGINT '1'"); + assertThat(assertions.query("INSERT INTO mock.tiny.nation(regionkey) VALUES (0)")) + .matches("SELECT BIGINT '1'"); + } + + @Test + public void testInsertCheckMultipleColumns() + { + assertThat(assertions.query("INSERT INTO mock.tiny.nation_multiple_column_constraint VALUES (101, 'POLAND', 51, 'No comment')")) + .matches("SELECT BIGINT '1'"); + + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_multiple_column_constraint VALUES (101, 'POLAND', 50, 'No comment')")) + .hasMessage("Check constraint violation: ((nationkey > 100) AND (regionkey > 50))"); + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_multiple_column_constraint VALUES (100, 'POLAND', 51, 'No comment')")) + .hasMessage("Check constraint violation: ((nationkey > 100) AND (regionkey > 50))"); + } + + @Test + public void testInsertSubquery() + { + assertThat(assertions.query("INSERT INTO mock.tiny.nation_subquery VALUES (26, 'POLAND', 51, 'No comment')")) + .matches("SELECT BIGINT '1'"); + + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_subquery VALUES (10, 'POLAND', 0, 'No comment')")) + .hasMessage("Check constraint violation: (nationkey > (SELECT count(*)\nFROM\n nation\n))"); + } + + @Test + public void testInsertUnsupportedCurrentDate() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_current_date VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Check constraint expression should not contain temporal expression"); + } + + @Test + public void testInsertUnsupportedCurrentTime() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_current_time VALUES (101, 'POLAND', 0, 'No 
comment')")) + .hasMessageContaining("Check constraint expression should not contain temporal expression"); + } + + @Test + public void testInsertUnsupportedCurrentTimestamp() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_current_timestamp VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Check constraint expression should not contain temporal expression"); + } + + @Test + public void testInsertUnsupportedLocaltime() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_localtime VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Check constraint expression should not contain temporal expression"); + } + + @Test + public void testInsertUnsupportedLocaltimestamp() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_localtimestamp VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Check constraint expression should not contain temporal expression"); + } + + @Test + public void testInsertUnsupportedConstraint() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_invalid_function VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Function 'invalid_function' not registered"); + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_not_boolean_expression VALUES (101, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("to be of type BOOLEAN, but was integer"); + } + + @Test + public void testInsertNotDeterministic() + { + assertThatThrownBy(() -> assertions.query("INSERT INTO mock.tiny.nation_not_deterministic VALUES (100, 'POLAND', 0, 'No comment')")) + .hasMessageContaining("Check constraint expression should be deterministic"); + } + + /** + * @see #testMergeDelete() + */ + @Test + public void testDelete() + { + assertThat(assertions.query("DELETE FROM mock.tiny.nation WHERE nationkey < 3")) + .matches("SELECT BIGINT '3'"); + assertThat(assertions.query("DELETE FROM 
mock.tiny.nation WHERE nationkey IN (1, 2, 3)")) + .matches("SELECT BIGINT '3'"); + } + + /** + * Like {@link #testDelete()} but using the MERGE statement. + */ + @Test + public void testMergeDelete() + { + // Within allowed check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1,2) t(x) ON nationkey = x + WHEN MATCHED THEN DELETE + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + + // Outside allowed check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1,2,3,4,5) t(x) ON regionkey = x + WHEN MATCHED THEN DELETE + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1,11) t(x) ON nationkey = x + WHEN MATCHED THEN DELETE + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 11,12,13,14,15) t(x) ON nationkey = x + WHEN MATCHED THEN DELETE + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + } + + /** + * @see #testMergeUpdate() + */ + @Test + public void testUpdate() + { + // Within allowed check constraint + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2 WHERE nationkey < 3")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2 WHERE nationkey IN (1, 2, 3)")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + + // Outside allowed check constraint + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + 
assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2 WHERE nationkey IN (1, 11)")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2 WHERE nationkey = 11")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + + // Within allowed check constraint, but updated rows are outside the check constraint + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET nationkey = 10 WHERE nationkey < 3")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET nationkey = null WHERE nationkey < 3")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + + // Outside allowed check constraint, and updated rows are outside the check constraint + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET nationkey = 10 WHERE nationkey = 10")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET nationkey = null WHERE nationkey = null ")) + .hasMessage("line 1:1: Updating a table with a check constraint is not supported"); + } + + /** + * Like {@link #testUpdate()} but using the MERGE statement. 
+ */ + @Test + public void testMergeUpdate() + { + // Within allowed check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 5) t(x) ON nationkey = x + WHEN MATCHED THEN UPDATE SET regionkey = regionkey * 2 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + + // Outside allowed check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1,2,3,4,5,6) t(x) ON regionkey = x + WHEN MATCHED THEN UPDATE SET regionkey = regionkey * 2 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1, 11) t(x) ON nationkey = x + WHEN MATCHED THEN UPDATE SET regionkey = regionkey * 2 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 11) t(x) ON nationkey = x + WHEN MATCHED THEN UPDATE SET regionkey = regionkey * 2 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + + // Within allowed check constraint, but updated rows are outside the check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1,2,3) t(x) ON nationkey = x + WHEN MATCHED THEN UPDATE SET nationkey = 10 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 1,2,3) t(x) ON nationkey = x + WHEN MATCHED THEN UPDATE SET nationkey = NULL + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + + // Outside allowed check constraint, but updated rows are outside the check constraint + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 10) t(x) ON nationkey = x + WHEN MATCHED 
THEN UPDATE SET nationkey = 13 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 10) t(x) ON nationkey = x + WHEN MATCHED THEN UPDATE SET nationkey = NULL + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + assertThatThrownBy(() -> assertions.query(""" + MERGE INTO mock.tiny.nation USING (VALUES 10) t(x) ON nationkey IS NULL + WHEN MATCHED THEN UPDATE SET nationkey = 13 + """)) + .hasMessage("line 1:1: Cannot merge into a table with check constraints"); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index cf5892e56ca9..f40589aab433 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -144,6 +144,7 @@ public enum StandardErrorCode INVALID_COPARTITIONING(120, USER_ERROR), INVALID_TABLE_FUNCTION_INVOCATION(121, USER_ERROR), DUPLICATE_RANGE_VARIABLE(122, USER_ERROR), + INVALID_CHECK_CONSTRAINT(123, USER_ERROR), GENERIC_INTERNAL_ERROR(65536, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(65537, INTERNAL_ERROR), diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java index cb1963095098..2de63797c2c2 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableMetadata.java @@ -13,6 +13,8 @@ */ package io.trino.spi.connector; +import io.trino.spi.Experimental; + import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -29,6 +31,7 @@ public class ConnectorTableMetadata private final Optional comment; private final List columns; private final Map properties; + private 
final List checkConstraints; public ConnectorTableMetadata(SchemaTableName table, List columns) { @@ -37,19 +40,27 @@ public ConnectorTableMetadata(SchemaTableName table, List column public ConnectorTableMetadata(SchemaTableName table, List columns, Map properties) { - this(table, columns, properties, Optional.empty()); + this(table, columns, properties, Optional.empty(), List.of()); } public ConnectorTableMetadata(SchemaTableName table, List columns, Map properties, Optional comment) + { + this(table, columns, properties, comment, List.of()); + } + + @Experimental(eta = "2023-03-31") + public ConnectorTableMetadata(SchemaTableName table, List columns, Map properties, Optional comment, List checkConstraints) { requireNonNull(table, "table is null"); requireNonNull(columns, "columns is null"); requireNonNull(comment, "comment is null"); + requireNonNull(checkConstraints, "checkConstraints is null"); this.table = table; this.columns = List.copyOf(columns); this.properties = Collections.unmodifiableMap(new LinkedHashMap<>(properties)); this.comment = comment; + this.checkConstraints = List.copyOf(checkConstraints); } public SchemaTableName getTable() @@ -72,13 +83,27 @@ public Optional getComment() return comment; } + /** + * List of constraints data in a table is expected to satisfy. + * Engine ensures rows written to a table meet these constraints. + * A check constraint is satisfied when it evaluates to True or Unknown. 
+ * + * @return List of string representation of a Trino SQL scalar expression that can refer to table columns by name and produces a result coercible to boolean + */ + @Experimental(eta = "2023-03-31") + public List getCheckConstraints() + { + return checkConstraints; + } + public ConnectorTableSchema getTableSchema() { return new ConnectorTableSchema( table, columns.stream() .map(ColumnMetadata::getColumnSchema) - .collect(toUnmodifiableList())); + .collect(toUnmodifiableList()), + checkConstraints); } @Override @@ -89,6 +114,7 @@ public String toString() sb.append(", columns=").append(columns); sb.append(", properties=").append(properties); comment.ifPresent(value -> sb.append(", comment='").append(value).append("'")); + sb.append(", checkConstraints=").append(checkConstraints); sb.append('}'); return sb.toString(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java index afa9694ecb66..186f1c733a0c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableSchema.java @@ -13,6 +13,8 @@ */ package io.trino.spi.connector; +import io.trino.spi.Experimental; + import java.util.List; import static java.util.Objects.requireNonNull; @@ -21,14 +23,23 @@ public class ConnectorTableSchema { private final SchemaTableName table; private final List columns; + private final List checkConstraints; public ConnectorTableSchema(SchemaTableName table, List columns) + { + this(table, columns, List.of()); + } + + @Experimental(eta = "2023-03-31") + public ConnectorTableSchema(SchemaTableName table, List columns, List checkConstraints) { requireNonNull(table, "table is null"); requireNonNull(columns, "columns is null"); + requireNonNull(checkConstraints, "checkConstraints is null"); this.table = table; this.columns = List.copyOf(columns); + this.checkConstraints 
= List.copyOf(checkConstraints); } public SchemaTableName getTable() @@ -41,12 +52,26 @@ public List getColumns() return columns; } + /** + * List of constraints data in a table is expected to satisfy. + * Engine ensures rows written to a table meet these constraints. + * A check constraint is satisfied when it evaluates to True or Unknown. + * + * @return List of string representation of a Trino SQL scalar expression that can refer to table columns by name and produces a result coercible to boolean + */ + @Experimental(eta = "2023-03-31") + public List getCheckConstraints() + { + return checkConstraints; + } + @Override public String toString() { return new StringBuilder("ConnectorTableSchema{") .append("table=").append(table) .append(", columns=").append(columns) + .append(", checkConstraints=").append(checkConstraints) .append('}') .toString(); }