Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.connector;

import com.google.common.collect.ImmutableList;
import io.trino.FullConnectorSession;
import io.trino.Session;
import io.trino.metadata.MaterializedViewDefinition;
Expand Down Expand Up @@ -54,12 +55,12 @@ public Optional<ConnectorTableSchema> getRelationMetadata(ConnectorSession conne

Optional<MaterializedViewDefinition> materializedView = metadata.getMaterializedView(session, qualifiedName);
if (materializedView.isPresent()) {
return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns())));
return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns()), ImmutableList.of()));
}

Optional<ViewDefinition> view = metadata.getView(session, qualifiedName);
if (view.isPresent()) {
return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns())));
return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns()), ImmutableList.of()));
}

Optional<TableHandle> tableHandle = metadata.getTableHandle(session, qualifiedName);
Expand Down
12 changes: 12 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ public class Analysis

private final Multiset<RowFilterScopeEntry> rowFilterScopes = HashMultiset.create();
private final Map<NodeRef<Table>, List<Expression>> rowFilters = new LinkedHashMap<>();
private final Map<NodeRef<Table>, List<Expression>> checkConstraints = new LinkedHashMap<>();

private final Multiset<ColumnMaskScopeEntry> columnMaskScopes = HashMultiset.create();
private final Map<NodeRef<Table>, Map<String, List<Expression>>> columnMasks = new LinkedHashMap<>();
Expand Down Expand Up @@ -1070,11 +1071,22 @@ public void addRowFilter(Table table, Expression filter)
.add(filter);
}

public void addCheckConstraints(Table table, Expression filter)
{
checkConstraints.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>())
.add(filter);
}

public List<Expression> getRowFilters(Table node)
{
return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of());
}

public List<Expression> getCheckConstraints(Table node)
{
return checkConstraints.getOrDefault(NodeRef.of(node), ImmutableList.of());
}

public boolean hasColumnMask(QualifiedObjectName table, String column, String identity)
{
return columnMaskScopes.contains(new ColumnMaskScopeEntry(table, column, identity));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)
List<ColumnSchema> columns = tableSchema.getColumns().stream()
.filter(column -> !column.isHidden())
.collect(toImmutableList());
List<String> checkConstraints = tableSchema.getTableSchema().getCheckConstraints();

for (ColumnSchema column : columns) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) {
Expand All @@ -544,7 +545,7 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)

Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, targetTableHandle.get());
List<Field> tableFields = analyzeTableOutputFields(insert.getTable(), targetTable, tableSchema, columnHandles);
analyzeFiltersAndMasks(insert.getTable(), targetTable, targetTableHandle, tableFields, session.getIdentity().getUser());
analyzeFiltersAndMasks(insert.getTable(), targetTable, targetTableHandle, new RelationType(tableFields), session.getIdentity().getUser(), checkConstraints);

List<String> tableColumns = columns.stream()
.map(ColumnSchema::getName)
Expand Down Expand Up @@ -791,7 +792,7 @@ protected Scope visitDelete(Delete node, Optional<Scope> scope)

analysis.setUpdateType("DELETE");
analysis.setUpdateTarget(tableName, Optional.of(table), Optional.empty());
analyzeFiltersAndMasks(table, tableName, Optional.of(handle), analysis.getScope(table).getRelationType(), session.getIdentity().getUser());
analyzeFiltersAndMasks(table, tableName, Optional.of(handle), analysis.getScope(table).getRelationType(), session.getIdentity().getUser(), ImmutableList.of());

return createAndAssignScope(node, scope, Field.newUnqualified("rows", BIGINT));
}
Expand Down Expand Up @@ -2145,10 +2146,10 @@ private void checkStorageTableNotRedirected(QualifiedObjectName source)

private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional<TableHandle> tableHandle, List<Field> fields, String authorization)
{
analyzeFiltersAndMasks(table, name, tableHandle, new RelationType(fields), authorization);
analyzeFiltersAndMasks(table, name, tableHandle, new RelationType(fields), authorization, ImmutableList.of());
}

private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional<TableHandle> tableHandle, RelationType relationType, String authorization)
private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional<TableHandle> tableHandle, RelationType relationType, String authorization, List<String> checkConstraints)
{
Scope accessControlScope = Scope.builder()
.withRelationType(RelationId.anonymous(), relationType)
Expand All @@ -2165,8 +2166,15 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optio
}
}

accessControl.getRowFilters(session.toSecurityContext(), name)
.forEach(filter -> analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter));
for (ViewExpression filter : accessControl.getRowFilters(session.toSecurityContext(), name)) {
Expression expression = analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter);
analysis.addRowFilter(table, expression);
}
for (String checkConstraint : checkConstraints) {
ViewExpression filter = new ViewExpression(session.getIdentity().getUser(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), checkConstraint);
Expression expression = analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter);
analysis.addCheckConstraints(table, expression);
}

analysis.registerTable(table, tableHandle, name, authorization, accessControlScope);
}
Expand Down Expand Up @@ -4491,7 +4499,7 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope,
correlationSupport);
}

private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter)
private Expression analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter)
{
if (analysis.hasRowFilter(name, currentIdentity)) {
throw new TrinoException(INVALID_ROW_FILTER, extractLocation(table), format("Row filter for '%s' is recursive", name), null);
Expand Down Expand Up @@ -4542,7 +4550,7 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje
analysis.addCoercion(expression, BOOLEAN, coercion.isTypeOnlyCoercion(actualType, BOOLEAN));
}

analysis.addRowFilter(table, expression);
return expression;
}

private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, Field field, Scope scope, ViewExpression mask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,18 +486,8 @@ private RelationPlan getInsertPlan(

plan = new RelationPlan(projectNode, scope, projectNode.getOutputSymbols(), Optional.empty());

plan = planner.addRowFilters(
table,
plan,
failIfPredicateIsNotMet(metadata, session, PERMISSION_DENIED, AccessDeniedException.PREFIX + "Cannot insert row that does not match to a row filter"),
node -> {
Scope accessControlScope = analysis.getAccessControlScope(table);
// hidden fields are not accessible in insert
return Scope.builder()
.like(accessControlScope)
.withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields())
.build();
});
plan = addRowFilters(analysis, planner, table, plan, failIfPredicateIsNotMet(metadata, session, PERMISSION_DENIED, AccessDeniedException.PREFIX + "Cannot insert row that does not match to a row filter"), analysis.getRowFilters(table));
plan = addRowFilters(analysis, planner, table, plan, failIfCheckConstraintIsNotMet(metadata, session), analysis.getCheckConstraints(table));

List<String> insertedTableColumnNames = insertedColumns.stream()
.map(ColumnMetadata::getName)
Expand Down Expand Up @@ -530,6 +520,23 @@ private RelationPlan getInsertPlan(
statisticsMetadata);
}

private RelationPlan addRowFilters(Analysis analysis, RelationPlanner planner, Table table, RelationPlan plan, Function<Expression, Expression> predicateTransformation, List<Expression> filters)
{
return planner.addRowFilters(
filters,
table,
plan,
predicateTransformation,
node -> {
Scope accessControlScope = analysis.getAccessControlScope(table);
// hidden fields are not accessible in insert
return Scope.builder()
.like(accessControlScope)
.withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields())
.build();
});
}

private Expression createNullNotAllowedFailExpression(String columnName, Type type)
{
return new Cast(failFunction(metadata, session, CONSTRAINT_VIOLATION, format(
Expand All @@ -543,6 +550,11 @@ private static Function<Expression, Expression> failIfPredicateIsNotMet(Metadata
return predicate -> new IfExpression(predicate, TRUE_LITERAL, new Cast(fail, toSqlType(BOOLEAN)));
}

private static Function<Expression, Expression> failIfCheckConstraintIsNotMet(Metadata metadata, Session session)
{
return predicate -> new IfExpression(predicate, TRUE_LITERAL, new Cast(failFunction(metadata, session, CONSTRAINT_VIOLATION, "Cannot insert row that does not match to a check constraint: " + predicate), toSqlType(BOOLEAN)));
}

public static FunctionCall failFunction(Metadata metadata, Session session, ErrorCodeSupplier errorCode, String errorMessage)
{
return FunctionCallBuilder.resolve(session, metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,11 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function<Expres
public RelationPlan addRowFilters(Table node, RelationPlan plan, Function<Expression, Expression> predicateTransformation, Function<Table, Scope> accessControlScope)
{
List<Expression> filters = analysis.getRowFilters(node);
return addRowFilters(filters, node, plan, predicateTransformation, accessControlScope);
}

public RelationPlan addRowFilters(List<Expression> filters, Table node, RelationPlan plan, Function<Expression, Expression> predicateTransformation, Function<Table, Scope> accessControlScope)
{
if (filters.isEmpty()) {
return plan;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ public class MockConnector
private final BiFunction<ConnectorSession, SchemaTableName, CompletableFuture<?>> refreshMaterializedView;
private final BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle;
private final Function<SchemaTableName, List<ColumnMetadata>> getColumns;
private final Function<SchemaTableName, List<String>> checkConstraints;
private final MockConnectorFactory.ApplyProjection applyProjection;
private final MockConnectorFactory.ApplyAggregation applyAggregation;
private final MockConnectorFactory.ApplyJoin applyJoin;
Expand Down Expand Up @@ -170,6 +171,7 @@ public class MockConnector
BiFunction<ConnectorSession, SchemaTableName, CompletableFuture<?>> refreshMaterializedView,
BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle,
Function<SchemaTableName, List<ColumnMetadata>> getColumns,
Function<SchemaTableName, List<String>> checkConstraints,
ApplyProjection applyProjection,
ApplyAggregation applyAggregation,
ApplyJoin applyJoin,
Expand Down Expand Up @@ -206,6 +208,7 @@ public class MockConnector
this.refreshMaterializedView = requireNonNull(refreshMaterializedView, "refreshMaterializedView is null");
this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null");
this.getColumns = requireNonNull(getColumns, "getColumns is null");
this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null");
this.applyProjection = requireNonNull(applyProjection, "applyProjection is null");
this.applyAggregation = requireNonNull(applyAggregation, "applyAggregation is null");
this.applyJoin = requireNonNull(applyJoin, "applyJoin is null");
Expand Down Expand Up @@ -441,7 +444,12 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable
public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle)
{
MockConnectorTableHandle table = (MockConnectorTableHandle) tableHandle;
return new ConnectorTableMetadata(table.getTableName(), getColumns.apply(table.getTableName()));
return new ConnectorTableMetadata(
table.getTableName(),
getColumns.apply(table.getTableName()),
ImmutableMap.of(),
Optional.empty(),
checkConstraints.apply(table.getTableName()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public class MockConnectorFactory
private final BiFunction<ConnectorSession, SchemaTableName, CompletableFuture<?>> refreshMaterializedView;
private final BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle;
private final Function<SchemaTableName, List<ColumnMetadata>> getColumns;
private final Function<SchemaTableName, List<String>> checkConstraints;
private final ApplyProjection applyProjection;
private final ApplyAggregation applyAggregation;
private final ApplyJoin applyJoin;
Expand Down Expand Up @@ -130,6 +131,7 @@ private MockConnectorFactory(
BiFunction<ConnectorSession, SchemaTableName, CompletableFuture<?>> refreshMaterializedView,
BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle,
Function<SchemaTableName, List<ColumnMetadata>> getColumns,
Function<SchemaTableName, List<String>> checkConstraints,
ApplyProjection applyProjection,
ApplyAggregation applyAggregation,
ApplyJoin applyJoin,
Expand Down Expand Up @@ -167,6 +169,7 @@ private MockConnectorFactory(
this.refreshMaterializedView = requireNonNull(refreshMaterializedView, "refreshMaterializedView is null");
this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null");
this.getColumns = requireNonNull(getColumns, "getColumns is null");
this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null");
this.applyProjection = requireNonNull(applyProjection, "applyProjection is null");
this.applyAggregation = requireNonNull(applyAggregation, "applyAggregation is null");
this.applyJoin = requireNonNull(applyJoin, "applyJoin is null");
Expand Down Expand Up @@ -214,6 +217,7 @@ public Connector create(String catalogName, Map<String, String> config, Connecto
refreshMaterializedView,
getTableHandle,
getColumns,
checkConstraints,
applyProjection,
applyAggregation,
applyJoin,
Expand Down Expand Up @@ -338,6 +342,7 @@ public static final class Builder
private BiFunction<ConnectorSession, SchemaTableName, CompletableFuture<?>> refreshMaterializedView = (session, viewName) -> CompletableFuture.completedFuture(null);
private BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle = defaultGetTableHandle();
private Function<SchemaTableName, List<ColumnMetadata>> getColumns = defaultGetColumns();
private Function<SchemaTableName, List<String>> checkConstraints = (schemaTableName -> ImmutableList.of());
private ApplyProjection applyProjection = (session, handle, projections, assignments) -> Optional.empty();
private ApplyAggregation applyAggregation = (session, handle, aggregates, assignments, groupingSets) -> Optional.empty();
private ApplyJoin applyJoin = (session, joinType, left, right, joinConditions, leftAssignments, rightAssignments) -> Optional.empty();
Expand Down Expand Up @@ -451,6 +456,12 @@ public Builder withGetColumns(Function<SchemaTableName, List<ColumnMetadata>> ge
return this;
}

public Builder withCheckConstraints(Function<SchemaTableName, List<String>> checkConstraints)
{
this.checkConstraints = requireNonNull(checkConstraints, "checkConstraints is null");
return this;
}

public Builder withApplyProjection(ApplyProjection applyProjection)
{
this.applyProjection = applyProjection;
Expand Down Expand Up @@ -647,6 +658,7 @@ public MockConnectorFactory build()
refreshMaterializedView,
getTableHandle,
getColumns,
checkConstraints,
applyProjection,
applyAggregation,
applyJoin,
Expand Down
Loading