diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControl.java b/core/trino-main/src/main/java/io/trino/security/AccessControl.java index b398dd544605..3a80a40d45a7 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControl.java @@ -566,8 +566,8 @@ default List getRowFilters(SecurityContext context, QualifiedObj return ImmutableList.of(); } - default List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + default Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { - return ImmutableList.of(); + return Optional.empty(); } } diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java index b1ff99897627..373405add53e 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java @@ -72,6 +72,7 @@ import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.configuration.ConfigurationLoader.loadPropertiesFrom; +import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.SERVER_STARTING_UP; import static java.lang.String.format; @@ -1248,26 +1249,30 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { requireNonNull(context, "context is null"); requireNonNull(tableName, "tableName is null"); ImmutableList.Builder masks = ImmutableList.builder(); - // connector-provided masks take precedence over global masks ConnectorAccessControl connectorAccessControl = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName()); if (connectorAccessControl != null) { - connectorAccessControl.getColumnMasks(toConnectorSecurityContext(tableName.getCatalogName(), context), tableName.asSchemaTableName(), columnName, type) - .forEach(masks::add); + connectorAccessControl.getColumnMask(toConnectorSecurityContext(tableName.getCatalogName(), context), tableName.asSchemaTableName(), columnName, type) + .ifPresent(masks::add); } for (SystemAccessControl systemAccessControl : getSystemAccessControls()) { - systemAccessControl.getColumnMasks(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type) - .forEach(masks::add); + systemAccessControl.getColumnMask(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type) + .ifPresent(masks::add); } - return masks.build(); + List allMasks = masks.build(); + if (allMasks.size() > 1) { + throw new TrinoException(INVALID_COLUMN_MASK, format("Column must have a single mask: %s", columnName)); + } + + return allMasks.stream().findFirst(); } private ConnectorAccessControl getConnectorAccessControl(TransactionId transactionId, String catalogName) diff --git a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java index 1c44eb89c378..89a54ca73941 100644 --- a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java @@ -486,8 +486,8 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { - return delegate().getColumnMasks(context, tableName, columnName, type); + return delegate().getColumnMask(context, tableName, columnName, type); } } diff --git a/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java b/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java index 76da1a5414c3..890a3276dfdf 100644 --- a/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java @@ -492,15 +492,21 @@ public List getRowFilters(ConnectorSecurityContext context, Sche } @Override - public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { checkArgument(context == null, "context must be null"); - if (accessControl.getColumnMasks(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) { - return ImmutableList.of(); + if (accessControl.getColumnMask(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) { + return Optional.empty(); } throw new TrinoException(NOT_SUPPORTED, "Column masking not supported"); } + @Override + public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + throw new UnsupportedOperationException(); + } + private QualifiedObjectName getQualifiedObjectName(SchemaTableName schemaTableName) { return new QualifiedObjectName(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName()); diff --git a/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java index f2b13cbf0c65..6e61a5a950cd 100644 --- a/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java @@ -22,6 +22,7 @@ import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Verify.verify; @@ -92,9 +93,9 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { - return delegate.getColumnMasks(context, tableName, columnName, type); + return delegate.getColumnMask(context, tableName, columnName, type); } private static void wrapAccessDeniedException(Runnable runnable) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 474ec713a9a2..3c80193e8134 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -214,7 +214,7 @@ public class Analysis private final Map, List> rowFilters = new LinkedHashMap<>(); private final Multiset columnMaskScopes = HashMultiset.create(); - private final Map, Map>> columnMasks = new LinkedHashMap<>(); + private final Map, Map> columnMasks = new LinkedHashMap<>(); private final Map, UnnestAnalysis> unnestAnalysis = new LinkedHashMap<>(); private Optional create = Optional.empty(); @@ -1093,12 +1093,13 @@ public void unregisterTableForColumnMasking(QualifiedObjectName table, String co public void addColumnMask(Table table, String column, Expression mask) { - Map> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>()); - masks.computeIfAbsent(column, name -> new ArrayList<>()) - .add(mask); + Map masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>()); + checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column); + + masks.put(column, mask); } - public Map> getColumnMasks(Table table) + public Map getColumnMasks(Table table) { return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()); } @@ -1118,10 +1119,8 @@ public List getReferencedTables() .distinct() .map(fieldName -> new ColumnInfo( fieldName, - columnMasks.getOrDefault(table, ImmutableMap.of()) - .getOrDefault(fieldName, ImmutableList.of()).stream() - .map(Expression::toString) - .collect(toImmutableList()))) + Optional.ofNullable(columnMasks.getOrDefault(table, ImmutableMap.of()).get(fieldName)) + .map(Expression::toString))) .collect(toImmutableList()); TableEntry info = entry.getValue(); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 5e597f29d26a..2e086469e08d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -544,7 +544,7 @@ protected Scope visitInsert(Insert insert, Optional scope) .collect(toImmutableList()); for (ColumnSchema column : columns) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, insert, "Insert into table with column masks is not supported"); } } @@ -785,7 +785,7 @@ protected Scope visitDelete(Delete node, Optional scope) TableSchema tableSchema = metadata.getTableSchema(session, handle); for (ColumnSchema tableColumn : tableSchema.getColumns()) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, node, "Delete from table with column mask"); } } @@ -1149,7 +1149,7 @@ protected Scope visitTableExecute(TableExecute node, Optional scope) TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); for (ColumnMetadata tableColumn : tableMetadata.getColumns()) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for table with column masks"); } } @@ -2222,10 +2222,10 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optio for (int index = 0; index < relationType.getAllFieldCount(); index++) { Field field = relationType.getFieldByIndex(index); if (field.getName().isPresent()) { - List masks = accessControl.getColumnMasks(session.toSecurityContext(), name, field.getName().get(), field.getType()); + Optional mask = accessControl.getColumnMask(session.toSecurityContext(), name, field.getName().get(), field.getType()); - if (!masks.isEmpty() && checkCanSelectFromColumn(name, field.getName().orElseThrow())) { - masks.forEach(mask -> analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask)); + if (mask.isPresent() && checkCanSelectFromColumn(name, field.getName().orElseThrow())) { + analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask.get()); } } } @@ -3178,7 +3178,7 @@ protected Scope visitUpdate(Update update, Optional scope) // TODO: how to deal with connectors that need to see the pre-image of rows to perform the update without // flowing that data through the masking logic for (ColumnSchema tableColumn : allColumns) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, update, "Updating a table with column masks is not supported"); } } @@ -3307,7 +3307,7 @@ protected Scope visitMerge(Merge merge, Optional scope) Scope joinScope = createAndAssignScope(merge, scope, targetTableScope.getRelationType().joinWith(sourceTableScope.getRelationType())); for (ColumnSchema column : dataColumnSchemas) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, column.getName(), column.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, column.getName(), column.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with column masks"); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 3c0e138d4dc9..d3e7c56fb3dd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -100,7 +100,6 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -293,7 +292,7 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function> columnMasks = analysis.getColumnMasks(table); + Map columnMasks = analysis.getColumnMasks(table); // A Table can represent a WITH query, which can have anonymous fields. On the other hand, // it can't have masks. The loop below expects fields to have proper names, so bail out @@ -305,27 +304,33 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext) .withScope(analysis.getAccessControlScope(table), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope + Assignments.Builder assignments = Assignments.builder(); + assignments.putIdentities(planBuilder.getRoot().getOutputSymbols()); + + List fieldMappings = new ArrayList<>(); for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { Field field = plan.getDescriptor().getFieldByIndex(i); - for (Expression mask : columnMasks.getOrDefault(field.getName().orElseThrow(), ImmutableList.of())) { + Expression mask = columnMasks.get(field.getName().orElseThrow()); + Symbol symbol = plan.getFieldMappings().get(i); + Expression projection = symbol.toSymbolReference(); + if (mask != null) { planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); - - Map assignments = new LinkedHashMap<>(); - for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { - assignments.put(symbol, symbol.toSymbolReference()); - } - assignments.put(plan.getFieldMappings().get(i), coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask))); - - planBuilder = planBuilder - .withNewRoot(new ProjectNode( - idAllocator.getNextId(), - planBuilder.getRoot(), - Assignments.copyOf(assignments))); + symbol = symbolAllocator.newSymbol(symbol); + projection = coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask)); } + + assignments.put(symbol, projection); + fieldMappings.add(symbol); } - return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext); + planBuilder = planBuilder + .withNewRoot(new ProjectNode( + idAllocator.getNextId(), + planBuilder.getRoot(), + assignments.build())); + + return new RelationPlan(planBuilder.getRoot(), plan.getScope(), fieldMappings, outerContext); } @Override diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index fea2b51d42ad..36d18a9c4fda 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -135,7 +135,7 @@ public class TestingAccessControlManager private final Set denyPrivileges = new HashSet<>(); private final Map> rowFilters = new HashMap<>(); - private final Map> columnMasks = new HashMap<>(); + private final Map columnMasks = new HashMap<>(); private Predicate deniedCatalogs = s -> true; private Predicate deniedSchemas = s -> true; private Predicate deniedTables = s -> true; @@ -175,8 +175,7 @@ public void rowFilter(QualifiedObjectName table, String identity, ViewExpression public void columnMask(QualifiedObjectName table, String column, String identity, ViewExpression mask) { - columnMasks.computeIfAbsent(new ColumnMaskKey(identity, table, column), key -> new ArrayList<>()) - .add(mask); + columnMasks.put(new ColumnMaskKey(identity, table, column), mask); } public void reset() @@ -746,13 +745,13 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String column, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String column, Type type) { - List viewExpressions = columnMasks.get(new ColumnMaskKey(context.getIdentity().getUser(), tableName, column)); - if (viewExpressions != null) { - return viewExpressions; + ViewExpression mask = columnMasks.get(new ColumnMaskKey(context.getIdentity().getUser(), tableName, column)); + if (mask != null) { + return Optional.of(mask); } - return super.getColumnMasks(context, tableName, column, type); + return super.getColumnMask(context, tableName, column, type); } private boolean shouldDenyPrivilege(String actorName, String entityName, TestingPrivilegeType verb) diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java index dd50b7d68311..bc763bcbd88d 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java @@ -129,12 +129,16 @@ public List getRowFilters(ConnectorSecurityContext context, Sche .orElseGet(ImmutableList::of); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.ofNullable(columnMasks.apply(tableName, columnName)); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { - return Optional.ofNullable(columnMasks.apply(tableName, columnName)) - .map(ImmutableList::of) - .orElseGet(ImmutableList::of); + throw new UnsupportedOperationException(); } public void grantSchemaPrivileges(String schemaName, Set privileges, TrinoPrincipal grantee, boolean grantOption) diff --git a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java index 8b3964568014..07ce6c40feca 100644 --- a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java +++ b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java @@ -58,7 +58,6 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.function.FunctionKind.TABLE; import static io.trino.spi.security.AccessDeniedException.denySelectTable; -import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.testing.TestingEventListenerManager.emptyEventListenerManager; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; @@ -258,17 +257,6 @@ public void checkCanShowCreateTable(ConnectorSecurityContext context, SchemaTabl { } }))); - - transaction(transactionManager, accessControlManager) - .execute(transactionId -> { - List masks = accessControlManager.getColumnMasks( - context(transactionId), - new QualifiedObjectName(TEST_CATALOG_NAME, "schema", "table"), - "column", - BIGINT); - assertEquals(masks.get(0).getExpression(), "connector mask"); - assertEquals(masks.get(1).getExpression(), "system mask"); - }); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 9b7078b05f7f..3b268a6a4462 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -26,6 +26,7 @@ import io.trino.spi.security.Identity; import io.trino.spi.security.ViewExpression; import io.trino.spi.type.BigintType; +import io.trino.spi.type.RowType; import io.trino.spi.type.VarcharType; import io.trino.testing.LocalQueryRunner; import io.trino.testing.TestingAccessControlManager; @@ -41,6 +42,7 @@ import static io.trino.connector.MockConnectorEntities.TPCH_NATION_WITH_HIDDEN_COLUMN; import static io.trino.connector.MockConnectorEntities.TPCH_WITH_HIDDEN_COLUMN_DATA; import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; import static io.trino.testing.TestingAccessControlManager.privilege; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; @@ -86,6 +88,26 @@ public void init() Optional.of(VIEW_OWNER), false); + ConnectorViewDefinition viewWithNested = new ConnectorViewDefinition( + """ + SELECT * FROM ( + VALUES + ROW(ROW(1,2), 0), + ROW(ROW(3,4), 1) + ) t(nested, id) + """, + Optional.empty(), + Optional.empty(), + ImmutableList.of( + new ConnectorViewDefinition.ViewColumn("nested", RowType.from(ImmutableList.of( + RowType.field(INTEGER), + RowType.field(INTEGER))).getTypeId(), + Optional.empty()), + new ConnectorViewDefinition.ViewColumn("id", INTEGER.getTypeId(), Optional.empty())), + Optional.empty(), + Optional.of(VIEW_OWNER), + false); + ConnectorMaterializedViewDefinition materializedView = new ConnectorMaterializedViewDefinition( "SELECT * FROM local.tiny.nation", Optional.empty(), @@ -142,7 +164,8 @@ public void init() throw new UnsupportedOperationException(); }) .withGetViews((s, prefix) -> ImmutableMap.of( - new SchemaTableName("default", "nation_view"), view)) + new SchemaTableName("default", "nation_view"), view, + new SchemaTableName("default", "view_with_nested"), viewWithNested)) .withGetMaterializedViews((s, prefix) -> ImmutableMap.of( new SchemaTableName("default", "nation_materialized_view"), materializedView, new SchemaTableName("default", "nation_fresh_materialized_view"), freshMaterializedView, @@ -196,25 +219,6 @@ public void testConditionalMask() .matches("VALUES (NULL), CAST('-781' AS BIGINT)"); } - @Test - public void testMultipleMasksOnSameColumn() - { - accessControl.reset(); - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "custkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); - - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "custkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey * 2")); - - assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '-740'"); - } - @Test public void testMultipleMasksOnDifferentColumns() { @@ -328,7 +332,7 @@ public void testMaterializedView() .setIdentity(Identity.forUser(USER).build()) .build(), "SELECT name FROM mock.default.materialized_view_with_casts WHERE nationkey = 1")) - .matches("VALUES CAST('RA' AS VARCHAR(2))"); + .matches("VALUES 'RA'"); } @Test @@ -614,22 +618,6 @@ public void testJoin() USER, new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey + 1")); assertThat(assertions.query("SELECT count(*) FROM orders JOIN orders USING (orderkey)")).matches("VALUES BIGINT '15000'"); - - // multiple masks - accessControl.reset(); - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "orderkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-orderkey")); - - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "orderkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey * 2")); - - assertThat(assertions.query("SELECT count(*) FROM orders JOIN orders USING (orderkey)")).matches("VALUES BIGINT '15000'"); } @Test @@ -847,7 +835,7 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), @@ -862,6 +850,20 @@ public void testMultipleMasksUsingOtherMaskedColumns() new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) - .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('***#000000951' as varchar(15)))"); + .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('###' as varchar(15)))"); + } + + @Test + public void testColumnAliasing() + { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(MOCK_CATALOG, "default", "view_with_nested"), + "nested", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(id = 0, nested)")); + + assertThat(assertions.query("SELECT nested[1] FROM mock.default.view_with_nested")) + .matches("VALUES 1, NULL"); } } diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 2875bfae937b..96b84c17f956 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -233,6 +233,20 @@ package It was not accessible outside of SPI anyway + + true + java.method.parameterTypeChanged + parameter void io.trino.spi.eventlistener.ColumnInfo::<init>(java.lang.String, ===java.util.List<java.lang.String>===) + parameter void io.trino.spi.eventlistener.ColumnInfo::<init>(java.lang.String, ===java.util.Optional<java.lang.String>===) + 1 + Removing support for multiple masks on a given column, as they are error prone + + + true + java.method.removed + method java.util.List<java.lang.String> io.trino.spi.eventlistener.ColumnInfo::getMasks() + Removing support for multiple masks on a given column, as they are error prone + diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java index 6e2320b7d5a3..585f270a1ca4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java @@ -659,13 +659,24 @@ default List getRowFilters(ConnectorSecurityContext context, Sch } /** - * Get column masks associated with the given table, column and identity. + * Get column mask associated with the given table, column and identity. *

- * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression * must be written in terms of columns in the table. * - * @return the list of masks, or empty list if not applicable + * @return the mask if present, or empty if not applicable */ + default Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + List masks = getColumnMasks(context, tableName, columnName, type); + if (masks.size() > 1) { + throw new UnsupportedOperationException("Multiple masks on a single column are no longer supported"); + } + + return masks.stream().findFirst(); + } + + @Deprecated default List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { return emptyList(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java index 4d26f514033a..23adca12b610 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java @@ -17,7 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.Unstable; -import java.util.List; +import java.util.Optional; /** * This class is JSON serializable for convenience and serialization compatibility is not guaranteed across versions. @@ -25,14 +25,14 @@ public class ColumnInfo { private final String column; - private final List masks; + private final Optional mask; @JsonCreator @Unstable - public ColumnInfo(String column, List masks) + public ColumnInfo(String column, Optional mask) { this.column = column; - this.masks = masks; + this.mask = mask; } @JsonProperty @@ -42,8 +42,8 @@ public String getColumn() } @JsonProperty - public List getMasks() + public Optional getMask() { - return masks; + return mask; } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java index 77df60d14a1d..1b0df6c1fba9 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java @@ -877,11 +877,22 @@ default List getRowFilters(SystemSecurityContext context, Catalo /** * Get column masks associated with the given table, column and identity. *

- * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression * must be written in terms of columns in the table. * - * @return the list of masks, or empty list if not applicable + * @return the mask if present, or empty if not applicable */ + default Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + List masks = getColumnMasks(context, tableName, columnName, type); + if (masks.size() > 1) { + throw new UnsupportedOperationException("Multiple masks on a single column are no longer supported"); + } + + return masks.stream().findFirst(); + } + + @Deprecated default List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) { return List.of(); diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java index 021c60b61088..1cde37af1469 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java @@ -534,6 +534,14 @@ public List getRowFilters(ConnectorSecurityContext context, Sche } } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getColumnMask(context, tableName, columnName, type); + } + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java index be94198f2a7e..00b8d4ed792e 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java @@ -344,6 +344,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return ImmutableList.of(); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java index 53b9abb364ba..1b71496a1e64 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java @@ -461,6 +461,12 @@ public List getRowFilters(SystemSecurityContext context, Catalog return emptyList(); } + @Override + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java index 4b113e5ebbf1..c4ce581c83cb 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSet; import io.trino.plugin.base.CatalogName; import io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege; +import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSecurityContext; import io.trino.spi.connector.SchemaRoutineName; @@ -44,6 +45,7 @@ import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.OWNERSHIP; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.SELECT; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.UPDATE; +import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyAlterColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentColumn; @@ -654,21 +656,33 @@ public List getRowFilters(ConnectorSecurityContext context, Sche } @Override - public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) { - return ImmutableList.of(); + return Optional.empty(); } ConnectorIdentity identity = context.getIdentity(); - return tableRules.stream() + List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName)) .map(rule -> rule.getColumnMask(identity.getUser(), catalogName, tableName.getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() .flatMap(Optional::stream) - .collect(toImmutableList()); + .toList(); + + if (masks.size() > 1) { + throw new TrinoException(INVALID_COLUMN_MASK, format("Multiple masks defined for %s.%s", tableName, columnName)); + } + + return masks.stream().findFirst(); + } + + @Override + public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + throw new UnsupportedOperationException(); } private boolean canSetSessionProperty(ConnectorSecurityContext context, String property) diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java index c59ab7368210..f208ac5be4ee 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java @@ -19,6 +19,7 @@ import io.airlift.bootstrap.Bootstrap; import io.trino.plugin.base.security.CatalogAccessControlRule.AccessMode; import io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege; +import io.trino.spi.TrinoException; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaRoutineName; import io.trino.spi.connector.CatalogSchemaTableName; @@ -52,6 +53,7 @@ import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.OWNERSHIP; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.SELECT; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.UPDATE; +import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyAlterColumn; import static io.trino.spi.security.AccessDeniedException.denyCatalogAccess; @@ -953,22 +955,34 @@ public List getRowFilters(SystemSecurityContext context, Catalog } @Override - public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type) + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type) { SchemaTableName tableName = table.getSchemaTableName(); if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) { - return ImmutableList.of(); + return Optional.empty(); } Identity identity = context.getIdentity(); - return tableRules.stream() + List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table)) .map(rule -> rule.getColumnMask(identity.getUser(), table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() .flatMap(Optional::stream) - .collect(toImmutableList()); + .toList(); + + if (masks.size() > 1) { + throw new TrinoException(INVALID_COLUMN_MASK, format("Multiple masks defined for %s.%s", table, columnName)); + } + + return masks.stream().findFirst(); + } + + @Override + public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type) + { + throw new UnsupportedOperationException(); } private boolean checkAnyCatalogAccess(SystemSecurityContext context, String catalogName) diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java index 4b0c2f49a5f6..eca7a3514867 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java @@ -417,6 +417,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return delegate().getRowFilters(context, tableName); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return delegate().getColumnMask(context, tableName, columnName, type); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java index c13f9de1f930..6f8897d64eb2 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java @@ -511,6 +511,12 @@ public List getRowFilters(SystemSecurityContext context, Catalog return delegate().getRowFilters(context, tableName); } + @Override + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + return delegate().getColumnMask(context, tableName, columnName, type); + } + @Override public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java index bb0fde56d9fc..d155baaced98 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java @@ -406,8 +406,8 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanDropTable(userGroup1Group2, myTable); accessControl.checkCanSelectFromColumns(userGroup1Group2, myTable, ImmutableSet.of()); assertEquals( - accessControl.getColumnMasks(userGroup1Group2, myTable, "col_a", VARCHAR), - ImmutableList.of()); + accessControl.getColumnMask(userGroup1Group2, myTable, "col_a", VARCHAR), + Optional.empty()); assertEquals( accessControl.getRowFilters(userGroup1Group2, myTable), ImmutableList.of()); @@ -418,7 +418,7 @@ public void testTableRulesForMixedGroupUsers() assertDenied(() -> accessControl.checkCanDropTable(userGroup2, myTable)); accessControl.checkCanSelectFromColumns(userGroup2, myTable, ImmutableSet.of()); assertViewExpressionEquals( - accessControl.getColumnMasks(userGroup2, myTable, "col_a", VARCHAR), + accessControl.getColumnMask(userGroup2, myTable, "col_a", VARCHAR).orElseThrow(), new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); assertEquals( accessControl.getRowFilters(userGroup2, myTable), @@ -433,8 +433,8 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanDropTable(userGroup1Group3, myTable); accessControl.checkCanSelectFromColumns(userGroup1Group3, myTable, ImmutableSet.of()); assertEquals( - accessControl.getColumnMasks(userGroup1Group3, myTable, "col_a", VARCHAR), - ImmutableList.of()); + accessControl.getColumnMask(userGroup1Group3, myTable, "col_a", VARCHAR), + Optional.empty()); assertDenied(() -> accessControl.checkCanCreateTable(userGroup3, myTable, Map.of())); assertDenied(() -> accessControl.checkCanInsertIntoTable(userGroup3, myTable)); @@ -442,17 +442,18 @@ public void testTableRulesForMixedGroupUsers() assertDenied(() -> accessControl.checkCanDropTable(userGroup3, myTable)); accessControl.checkCanSelectFromColumns(userGroup3, myTable, ImmutableSet.of()); assertViewExpressionEquals( - accessControl.getColumnMasks(userGroup3, myTable, "col_a", VARCHAR), + accessControl.getColumnMask(userGroup3, myTable, "col_a", VARCHAR).orElseThrow(), new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); + + List rowFilters = accessControl.getRowFilters(userGroup3, myTable); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters(userGroup3, myTable), + rowFilters.get(0), new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "country='US'")); } - private static void assertViewExpressionEquals(List result, ViewExpression expected) + private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(result.size(), 1); - ViewExpression actual = result.get(0); assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java index cb1c2c68b2f2..3b687a4db496 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java @@ -789,19 +789,19 @@ public void testTableRulesForMixedGroupUsers() .withGroups(ImmutableSet.of("group2")).build(), Optional.empty()); assertEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( userGroup1Group2, new CatalogSchemaTableName("some-catalog", "my_schema", "my_table"), "col_a", VARCHAR), - ImmutableList.of()); + Optional.empty()); assertViewExpressionEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( userGroup2, new CatalogSchemaTableName("some-catalog", "my_schema", "my_table"), "col_a", - VARCHAR), + VARCHAR).orElseThrow(), new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "'mask_a'")); SystemSecurityContext userGroup1Group3 = new SystemSecurityContext(Identity.forUser("user_1_3") @@ -815,10 +815,12 @@ public void testTableRulesForMixedGroupUsers() new CatalogSchemaTableName("some-catalog", "my_schema", "my_table")), ImmutableList.of()); + List rowFilters = accessControl.getRowFilters( + userGroup3, + new CatalogSchemaTableName("some-catalog", "my_schema", "my_table")); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters( - userGroup3, - new CatalogSchemaTableName("some-catalog", "my_schema", "my_table")), + rowFilters.get(0), new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "country='US'")); } @@ -1410,27 +1412,27 @@ public void testGetColumnMask() SystemAccessControl accessControl = newFileBasedSystemAccessControl("file-based-system-access-table.json"); assertEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( ALICE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked", VARCHAR), - ImmutableList.of()); + Optional.empty()); assertViewExpressionEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked", - VARCHAR), + VARCHAR).orElseThrow(), new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'")); assertViewExpressionEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked_with_user", - VARCHAR), + VARCHAR).orElseThrow(), new ViewExpression("mask-user", Optional.of("some-catalog"), Optional.of("bobschema"), "'mask-with-user'")); } @@ -1443,19 +1445,21 @@ public void testGetRowFilter() accessControl.getRowFilters(ALICE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")), ImmutableList.of()); + List rowFilters = accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")), + rowFilters.get(0), new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')")); + rowFilters = accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")), + rowFilters.get(0), new ViewExpression("filter-user", Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')")); } - private static void assertViewExpressionEquals(List result, ViewExpression expected) + private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(result.size(), 1); - ViewExpression actual = result.get(0); assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java index cb175ea34c10..62e1ca2fc621 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java @@ -421,6 +421,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return ImmutableList.of(); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java index 435ec58c4c68..0587986e8ce7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java @@ -631,6 +631,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return ImmutableList.of(); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index 90981d6844c5..414df7fbc3c8 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -353,7 +353,7 @@ public void testReferencedTablesAndRoutines() ColumnInfo column = table.getColumns().get(0); assertEquals(column.getColumn(), "linenumber"); - assertTrue(column.getMasks().isEmpty()); + assertTrue(column.getMask().isEmpty()); List routines = event.getMetadata().getRoutines(); assertEquals(tables.size(), 1); @@ -385,7 +385,7 @@ public void testReferencedTablesWithViews() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("nationkey"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -398,7 +398,7 @@ public void testReferencedTablesWithViews() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_column"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test @@ -422,7 +422,7 @@ public void testReferencedTablesWithMaterializedViews() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("nationkey"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -435,7 +435,7 @@ public void testReferencedTablesWithMaterializedViews() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_column"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test @@ -522,7 +522,7 @@ public void testReferencedTablesWithRowFilter() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("name"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -535,7 +535,7 @@ public void testReferencedTablesWithRowFilter() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_varchar"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test @@ -570,7 +570,7 @@ public void testReferencedTablesWithColumnMask() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("orderkey"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -583,11 +583,11 @@ public void testReferencedTablesWithColumnMask() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_varchar"); - assertThat(column.getMasks()).hasSize(1); + assertThat(column.getMask()).isPresent(); column = table.getColumns().get(1); assertThat(column.getColumn()).isEqualTo("test_bigint"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test