From d95101b68acdaea3defa659463f301f17be7abe6 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 11 Jan 2023 17:40:22 -0800 Subject: [PATCH 1/2] Disallow multiple masks on a given column This is problematic because: * There is no guarantee about the ordering of application of masks, which could result in non-deterministic results at best or query failures in the worst case * Allowing multiple masks, especially when provided by different connectors, means that one cannot reason about a mask expression in isolation with respect to the input the expression expects. --- .../java/io/trino/security/AccessControl.java | 4 +- .../trino/security/AccessControlManager.java | 19 +++++---- .../security/ForwardingAccessControl.java | 4 +- .../InjectedConnectorAccessControl.java | 12 ++++-- .../io/trino/security/ViewAccessControl.java | 5 ++- .../java/io/trino/sql/analyzer/Analysis.java | 17 ++++---- .../trino/sql/analyzer/StatementAnalyzer.java | 16 ++++---- .../io/trino/sql/planner/RelationPlanner.java | 5 ++- .../testing/TestingAccessControlManager.java | 15 ++++--- .../connector/MockConnectorAccessControl.java | 10 +++-- .../security/TestAccessControlManager.java | 12 ------ .../io/trino/sql/query/TestColumnMask.java | 35 ---------------- core/trino-spi/pom.xml | 14 +++++++ .../spi/connector/ConnectorAccessControl.java | 17 ++++++-- .../trino/spi/eventlistener/ColumnInfo.java | 12 +++--- .../spi/security/SystemAccessControl.java | 15 ++++++- ...ClassLoaderSafeConnectorAccessControl.java | 8 ++++ .../base/security/AllowAllAccessControl.java | 6 +++ .../security/AllowAllSystemAccessControl.java | 6 +++ .../base/security/FileBasedAccessControl.java | 22 ++++++++-- .../FileBasedSystemAccessControl.java | 22 ++++++++-- .../ForwardingConnectorAccessControl.java | 6 +++ .../ForwardingSystemAccessControl.java | 6 +++ ...seFileBasedConnectorAccessControlTest.java | 21 +++++----- .../BaseFileBasedSystemAccessControlTest.java | 40 ++++++++++--------- .../hive/security/LegacyAccessControl.java | 6 +++ .../security/SqlStandardAccessControl.java | 6 +++ .../execution/TestEventListenerBasic.java | 20 +++++----- 28 files changed, 231 insertions(+), 150 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControl.java b/core/trino-main/src/main/java/io/trino/security/AccessControl.java index b398dd544605..3a80a40d45a7 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControl.java @@ -566,8 +566,8 @@ default List getRowFilters(SecurityContext context, QualifiedObj return ImmutableList.of(); } - default List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + default Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { - return ImmutableList.of(); + return Optional.empty(); } } diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java index b1ff99897627..373405add53e 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java @@ -72,6 +72,7 @@ import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.configuration.ConfigurationLoader.loadPropertiesFrom; +import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.SERVER_STARTING_UP; import static java.lang.String.format; @@ -1248,26 +1249,30 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { requireNonNull(context, "context is null"); requireNonNull(tableName, "tableName is null"); ImmutableList.Builder masks = ImmutableList.builder(); - // connector-provided masks take precedence over global masks ConnectorAccessControl connectorAccessControl = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName()); if (connectorAccessControl != null) { - connectorAccessControl.getColumnMasks(toConnectorSecurityContext(tableName.getCatalogName(), context), tableName.asSchemaTableName(), columnName, type) - .forEach(masks::add); + connectorAccessControl.getColumnMask(toConnectorSecurityContext(tableName.getCatalogName(), context), tableName.asSchemaTableName(), columnName, type) + .ifPresent(masks::add); } for (SystemAccessControl systemAccessControl : getSystemAccessControls()) { - systemAccessControl.getColumnMasks(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type) - .forEach(masks::add); + systemAccessControl.getColumnMask(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type) + .ifPresent(masks::add); } - return masks.build(); + List allMasks = masks.build(); + if (allMasks.size() > 1) { + throw new TrinoException(INVALID_COLUMN_MASK, format("Column must have a single mask: %s", columnName)); + } + + return allMasks.stream().findFirst(); } private ConnectorAccessControl getConnectorAccessControl(TransactionId transactionId, String catalogName) diff --git a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java index 1c44eb89c378..89a54ca73941 100644 --- a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java @@ -486,8 +486,8 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { - return delegate().getColumnMasks(context, tableName, columnName, type); + return delegate().getColumnMask(context, tableName, columnName, type); } } diff --git a/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java b/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java index 76da1a5414c3..890a3276dfdf 100644 --- a/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java @@ -492,15 +492,21 @@ public List getRowFilters(ConnectorSecurityContext context, Sche } @Override - public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { checkArgument(context == null, "context must be null"); - if (accessControl.getColumnMasks(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) { - return ImmutableList.of(); + if (accessControl.getColumnMask(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) { + return Optional.empty(); } throw new TrinoException(NOT_SUPPORTED, "Column masking not supported"); } + @Override + public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + throw new UnsupportedOperationException(); + } + private QualifiedObjectName getQualifiedObjectName(SchemaTableName schemaTableName) { return new QualifiedObjectName(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName()); diff --git a/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java index f2b13cbf0c65..6e61a5a950cd 100644 --- a/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ViewAccessControl.java @@ -22,6 +22,7 @@ import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Verify.verify; @@ -92,9 +93,9 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type) { - return delegate.getColumnMasks(context, tableName, columnName, type); + return delegate.getColumnMask(context, tableName, columnName, type); } private static void wrapAccessDeniedException(Runnable runnable) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 474ec713a9a2..3c80193e8134 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -214,7 +214,7 @@ public class Analysis private final Map, List> rowFilters = new LinkedHashMap<>(); private final Multiset columnMaskScopes = HashMultiset.create(); - private final Map, Map>> columnMasks = new LinkedHashMap<>(); + private final Map, Map> columnMasks = new LinkedHashMap<>(); private final Map, UnnestAnalysis> unnestAnalysis = new LinkedHashMap<>(); private Optional create = Optional.empty(); @@ -1093,12 +1093,13 @@ public void unregisterTableForColumnMasking(QualifiedObjectName table, String co public void addColumnMask(Table table, String column, Expression mask) { - Map> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>()); - masks.computeIfAbsent(column, name -> new ArrayList<>()) - .add(mask); + Map masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>()); + checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column); + + masks.put(column, mask); } - public Map> getColumnMasks(Table table) + public Map getColumnMasks(Table table) { return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()); } @@ -1118,10 +1119,8 @@ public List getReferencedTables() .distinct() .map(fieldName -> new ColumnInfo( fieldName, - columnMasks.getOrDefault(table, ImmutableMap.of()) - .getOrDefault(fieldName, ImmutableList.of()).stream() - .map(Expression::toString) - .collect(toImmutableList()))) + Optional.ofNullable(columnMasks.getOrDefault(table, ImmutableMap.of()).get(fieldName)) + .map(Expression::toString))) .collect(toImmutableList()); TableEntry info = entry.getValue(); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 5e597f29d26a..2e086469e08d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -544,7 +544,7 @@ protected Scope visitInsert(Insert insert, Optional scope) .collect(toImmutableList()); for (ColumnSchema column : columns) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, insert, "Insert into table with column masks is not supported"); } } @@ -785,7 +785,7 @@ protected Scope visitDelete(Delete node, Optional scope) TableSchema tableSchema = metadata.getTableSchema(session, handle); for (ColumnSchema tableColumn : tableSchema.getColumns()) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, node, "Delete from table with column mask"); } } @@ -1149,7 +1149,7 @@ protected Scope visitTableExecute(TableExecute node, Optional scope) TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); for (ColumnMetadata tableColumn : tableMetadata.getColumns()) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for table with column masks"); } } @@ -2222,10 +2222,10 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optio for (int index = 0; index < relationType.getAllFieldCount(); index++) { Field field = relationType.getFieldByIndex(index); if (field.getName().isPresent()) { - List masks = accessControl.getColumnMasks(session.toSecurityContext(), name, field.getName().get(), field.getType()); + Optional mask = accessControl.getColumnMask(session.toSecurityContext(), name, field.getName().get(), field.getType()); - if (!masks.isEmpty() && checkCanSelectFromColumn(name, field.getName().orElseThrow())) { - masks.forEach(mask -> analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask)); + if (mask.isPresent() && checkCanSelectFromColumn(name, field.getName().orElseThrow())) { + analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask.get()); } } } @@ -3178,7 +3178,7 @@ protected Scope visitUpdate(Update update, Optional scope) // TODO: how to deal with connectors that need to see the pre-image of rows to perform the update without // flowing that data through the masking logic for (ColumnSchema tableColumn : allColumns) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, update, "Updating a table with column masks is not supported"); } } @@ -3307,7 +3307,7 @@ protected Scope visitMerge(Merge merge, Optional scope) Scope joinScope = createAndAssignScope(merge, scope, targetTableScope.getRelationType().joinWith(sourceTableScope.getRelationType())); for (ColumnSchema column : dataColumnSchemas) { - if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, column.getName(), column.getType()).isEmpty()) { + if (accessControl.getColumnMask(session.toSecurityContext(), tableName, column.getName(), column.getType()).isPresent()) { throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with column masks"); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 3c0e138d4dc9..b1f94ad5d718 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -293,7 +293,7 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function> columnMasks = analysis.getColumnMasks(table); + Map columnMasks = analysis.getColumnMasks(table); // A Table can represent a WITH query, which can have anonymous fields. On the other hand, // it can't have masks. The loop below expects fields to have proper names, so bail out @@ -308,7 +308,8 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { Field field = plan.getDescriptor().getFieldByIndex(i); - for (Expression mask : columnMasks.getOrDefault(field.getName().orElseThrow(), ImmutableList.of())) { + Expression mask = columnMasks.get(field.getName().orElseThrow()); + if (mask != null) { planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); Map assignments = new LinkedHashMap<>(); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index fea2b51d42ad..36d18a9c4fda 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -135,7 +135,7 @@ public class TestingAccessControlManager private final Set denyPrivileges = new HashSet<>(); private final Map> rowFilters = new HashMap<>(); - private final Map> columnMasks = new HashMap<>(); + private final Map columnMasks = new HashMap<>(); private Predicate deniedCatalogs = s -> true; private Predicate deniedSchemas = s -> true; private Predicate deniedTables = s -> true; @@ -175,8 +175,7 @@ public void rowFilter(QualifiedObjectName table, String identity, ViewExpression public void columnMask(QualifiedObjectName table, String column, String identity, ViewExpression mask) { - columnMasks.computeIfAbsent(new ColumnMaskKey(identity, table, column), key -> new ArrayList<>()) - .add(mask); + columnMasks.put(new ColumnMaskKey(identity, table, column), mask); } public void reset() @@ -746,13 +745,13 @@ public List getRowFilters(SecurityContext context, QualifiedObje } @Override - public List getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String column, Type type) + public Optional getColumnMask(SecurityContext context, QualifiedObjectName tableName, String column, Type type) { - List viewExpressions = columnMasks.get(new ColumnMaskKey(context.getIdentity().getUser(), tableName, column)); - if (viewExpressions != null) { - return viewExpressions; + ViewExpression mask = columnMasks.get(new ColumnMaskKey(context.getIdentity().getUser(), tableName, column)); + if (mask != null) { + return Optional.of(mask); } - return super.getColumnMasks(context, tableName, column, type); + return super.getColumnMask(context, tableName, column, type); } private boolean shouldDenyPrivilege(String actorName, String entityName, TestingPrivilegeType verb) diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java index dd50b7d68311..bc763bcbd88d 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorAccessControl.java @@ -129,12 +129,16 @@ public List getRowFilters(ConnectorSecurityContext context, Sche .orElseGet(ImmutableList::of); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.ofNullable(columnMasks.apply(tableName, columnName)); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { - return Optional.ofNullable(columnMasks.apply(tableName, columnName)) - .map(ImmutableList::of) - .orElseGet(ImmutableList::of); + throw new UnsupportedOperationException(); } public void grantSchemaPrivileges(String schemaName, Set privileges, TrinoPrincipal grantee, boolean grantOption) diff --git a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java index 8b3964568014..07ce6c40feca 100644 --- a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java +++ b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java @@ -58,7 +58,6 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.function.FunctionKind.TABLE; import static io.trino.spi.security.AccessDeniedException.denySelectTable; -import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.testing.TestingEventListenerManager.emptyEventListenerManager; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; @@ -258,17 +257,6 @@ public void checkCanShowCreateTable(ConnectorSecurityContext context, SchemaTabl { } }))); - - transaction(transactionManager, accessControlManager) - .execute(transactionId -> { - List masks = accessControlManager.getColumnMasks( - context(transactionId), - new QualifiedObjectName(TEST_CATALOG_NAME, "schema", "table"), - "column", - BIGINT); - assertEquals(masks.get(0).getExpression(), "connector mask"); - assertEquals(masks.get(1).getExpression(), "system mask"); - }); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 9b7078b05f7f..7da2e068774c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -196,25 +196,6 @@ public void testConditionalMask() .matches("VALUES (NULL), CAST('-781' AS BIGINT)"); } - @Test - public void testMultipleMasksOnSameColumn() - { - accessControl.reset(); - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "custkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); - - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "custkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey * 2")); - - assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '-740'"); - } - @Test public void testMultipleMasksOnDifferentColumns() { @@ -614,22 +595,6 @@ public void testJoin() USER, new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey + 1")); assertThat(assertions.query("SELECT count(*) FROM orders JOIN orders USING (orderkey)")).matches("VALUES BIGINT '15000'"); - - // multiple masks - accessControl.reset(); - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "orderkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-orderkey")); - - accessControl.columnMask( - new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), - "orderkey", - USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey * 2")); - - assertThat(assertions.query("SELECT count(*) FROM orders JOIN orders USING (orderkey)")).matches("VALUES BIGINT '15000'"); } @Test diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 2875bfae937b..96b84c17f956 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -233,6 +233,20 @@ package It was not accessible outside of SPI anyway + + true + java.method.parameterTypeChanged + parameter void io.trino.spi.eventlistener.ColumnInfo::<init>(java.lang.String, ===java.util.List<java.lang.String>===) + parameter void io.trino.spi.eventlistener.ColumnInfo::<init>(java.lang.String, ===java.util.Optional<java.lang.String>===) + 1 + Removing support for multiple masks on a given column, as they are error prone + + + true + java.method.removed + method java.util.List<java.lang.String> io.trino.spi.eventlistener.ColumnInfo::getMasks() + Removing support for multiple masks on a given column, as they are error prone + diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java index 6e2320b7d5a3..585f270a1ca4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java @@ -659,13 +659,24 @@ default List getRowFilters(ConnectorSecurityContext context, Sch } /** - * Get column masks associated with the given table, column and identity. + * Get column mask associated with the given table, column and identity. *

- * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression * must be written in terms of columns in the table. * - * @return the list of masks, or empty list if not applicable + * @return the mask if present, or empty if not applicable */ + default Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + List masks = getColumnMasks(context, tableName, columnName, type); + if (masks.size() > 1) { + throw new UnsupportedOperationException("Multiple masks on a single column are no longer supported"); + } + + return masks.stream().findFirst(); + } + + @Deprecated default List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { return emptyList(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java index 4d26f514033a..23adca12b610 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ColumnInfo.java @@ -17,7 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.Unstable; -import java.util.List; +import java.util.Optional; /** * This class is JSON serializable for convenience and serialization compatibility is not guaranteed across versions. @@ -25,14 +25,14 @@ public class ColumnInfo { private final String column; - private final List masks; + private final Optional mask; @JsonCreator @Unstable - public ColumnInfo(String column, List masks) + public ColumnInfo(String column, Optional mask) { this.column = column; - this.masks = masks; + this.mask = mask; } @JsonProperty @@ -42,8 +42,8 @@ public String getColumn() } @JsonProperty - public List getMasks() + public Optional getMask() { - return masks; + return mask; } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java index 77df60d14a1d..1b0df6c1fba9 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java @@ -877,11 +877,22 @@ default List getRowFilters(SystemSecurityContext context, Catalo /** * Get column masks associated with the given table, column and identity. *

- * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression * must be written in terms of columns in the table. * - * @return the list of masks, or empty list if not applicable + * @return the mask if present, or empty if not applicable */ + default Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + List masks = getColumnMasks(context, tableName, columnName, type); + if (masks.size() > 1) { + throw new UnsupportedOperationException("Multiple masks on a single column are no longer supported"); + } + + return masks.stream().findFirst(); + } + + @Deprecated default List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) { return List.of(); diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java index 021c60b61088..1cde37af1469 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java @@ -534,6 +534,14 @@ public List getRowFilters(ConnectorSecurityContext context, Sche } } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getColumnMask(context, tableName, columnName, type); + } + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java index be94198f2a7e..00b8d4ed792e 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java @@ -344,6 +344,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return ImmutableList.of(); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java index 53b9abb364ba..1b71496a1e64 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java @@ -461,6 +461,12 @@ public List getRowFilters(SystemSecurityContext context, Catalog return emptyList(); } + @Override + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java index 4b113e5ebbf1..c4ce581c83cb 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSet; import io.trino.plugin.base.CatalogName; import io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege; +import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSecurityContext; import io.trino.spi.connector.SchemaRoutineName; @@ -44,6 +45,7 @@ import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.OWNERSHIP; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.SELECT; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.UPDATE; +import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyAlterColumn; import static io.trino.spi.security.AccessDeniedException.denyCommentColumn; @@ -654,21 +656,33 @@ public List getRowFilters(ConnectorSecurityContext context, Sche } @Override - public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) { - return ImmutableList.of(); + return Optional.empty(); } ConnectorIdentity identity = context.getIdentity(); - return tableRules.stream() + List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName)) .map(rule -> rule.getColumnMask(identity.getUser(), catalogName, tableName.getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() .flatMap(Optional::stream) - .collect(toImmutableList()); + .toList(); + + if (masks.size() > 1) { + throw new TrinoException(INVALID_COLUMN_MASK, format("Multiple masks defined for %s.%s", tableName, columnName)); + } + + return masks.stream().findFirst(); + } + + @Override + public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + throw new UnsupportedOperationException(); } private boolean canSetSessionProperty(ConnectorSecurityContext context, String property) diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java index c59ab7368210..f208ac5be4ee 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java @@ -19,6 +19,7 @@ import io.airlift.bootstrap.Bootstrap; import io.trino.plugin.base.security.CatalogAccessControlRule.AccessMode; import io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege; +import io.trino.spi.TrinoException; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaRoutineName; import io.trino.spi.connector.CatalogSchemaTableName; @@ -52,6 +53,7 @@ import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.OWNERSHIP; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.SELECT; import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.UPDATE; +import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static io.trino.spi.security.AccessDeniedException.denyAddColumn; import static io.trino.spi.security.AccessDeniedException.denyAlterColumn; import static io.trino.spi.security.AccessDeniedException.denyCatalogAccess; @@ -953,22 +955,34 @@ public List getRowFilters(SystemSecurityContext context, Catalog } @Override - public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type) + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type) { SchemaTableName tableName = table.getSchemaTableName(); if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) { - return ImmutableList.of(); + return Optional.empty(); } Identity identity = context.getIdentity(); - return tableRules.stream() + List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table)) .map(rule -> rule.getColumnMask(identity.getUser(), table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() .flatMap(Optional::stream) - .collect(toImmutableList()); + .toList(); + + if (masks.size() > 1) { + throw new TrinoException(INVALID_COLUMN_MASK, format("Multiple masks defined for %s.%s", table, columnName)); + } + + return masks.stream().findFirst(); + } + + @Override + public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type) + { + throw new UnsupportedOperationException(); } private boolean checkAnyCatalogAccess(SystemSecurityContext context, String catalogName) diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java index 4b0c2f49a5f6..eca7a3514867 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java @@ -417,6 +417,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return delegate().getRowFilters(context, tableName); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return delegate().getColumnMask(context, tableName, columnName, type); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java index c13f9de1f930..6f8897d64eb2 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java @@ -511,6 +511,12 @@ public List getRowFilters(SystemSecurityContext context, Catalog return delegate().getRowFilters(context, tableName); } + @Override + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + return delegate().getColumnMask(context, tableName, columnName, type); + } + @Override public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) { diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java index bb0fde56d9fc..d155baaced98 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java @@ -406,8 +406,8 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanDropTable(userGroup1Group2, myTable); accessControl.checkCanSelectFromColumns(userGroup1Group2, myTable, ImmutableSet.of()); assertEquals( - accessControl.getColumnMasks(userGroup1Group2, myTable, "col_a", VARCHAR), - ImmutableList.of()); + accessControl.getColumnMask(userGroup1Group2, myTable, "col_a", VARCHAR), + Optional.empty()); assertEquals( accessControl.getRowFilters(userGroup1Group2, myTable), ImmutableList.of()); @@ -418,7 +418,7 @@ public void testTableRulesForMixedGroupUsers() assertDenied(() -> accessControl.checkCanDropTable(userGroup2, myTable)); accessControl.checkCanSelectFromColumns(userGroup2, myTable, ImmutableSet.of()); assertViewExpressionEquals( - accessControl.getColumnMasks(userGroup2, myTable, "col_a", VARCHAR), + accessControl.getColumnMask(userGroup2, myTable, "col_a", VARCHAR).orElseThrow(), new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); assertEquals( accessControl.getRowFilters(userGroup2, myTable), @@ -433,8 +433,8 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanDropTable(userGroup1Group3, myTable); accessControl.checkCanSelectFromColumns(userGroup1Group3, myTable, ImmutableSet.of()); assertEquals( - accessControl.getColumnMasks(userGroup1Group3, myTable, "col_a", VARCHAR), - ImmutableList.of()); + accessControl.getColumnMask(userGroup1Group3, myTable, "col_a", VARCHAR), + Optional.empty()); assertDenied(() -> accessControl.checkCanCreateTable(userGroup3, myTable, Map.of())); assertDenied(() -> accessControl.checkCanInsertIntoTable(userGroup3, myTable)); @@ -442,17 +442,18 @@ public void testTableRulesForMixedGroupUsers() assertDenied(() -> accessControl.checkCanDropTable(userGroup3, myTable)); accessControl.checkCanSelectFromColumns(userGroup3, myTable, ImmutableSet.of()); assertViewExpressionEquals( - accessControl.getColumnMasks(userGroup3, myTable, "col_a", VARCHAR), + accessControl.getColumnMask(userGroup3, myTable, "col_a", VARCHAR).orElseThrow(), new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); + + List rowFilters = accessControl.getRowFilters(userGroup3, myTable); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters(userGroup3, myTable), + rowFilters.get(0), new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "country='US'")); } - private static void assertViewExpressionEquals(List result, ViewExpression expected) + private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(result.size(), 1); - ViewExpression actual = result.get(0); assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java index cb1c2c68b2f2..3b687a4db496 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java @@ -789,19 +789,19 @@ public void testTableRulesForMixedGroupUsers() .withGroups(ImmutableSet.of("group2")).build(), Optional.empty()); assertEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( userGroup1Group2, new CatalogSchemaTableName("some-catalog", "my_schema", "my_table"), "col_a", VARCHAR), - ImmutableList.of()); + Optional.empty()); assertViewExpressionEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( userGroup2, new CatalogSchemaTableName("some-catalog", "my_schema", "my_table"), "col_a", - VARCHAR), + VARCHAR).orElseThrow(), new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "'mask_a'")); SystemSecurityContext userGroup1Group3 = new SystemSecurityContext(Identity.forUser("user_1_3") @@ -815,10 +815,12 @@ public void testTableRulesForMixedGroupUsers() new CatalogSchemaTableName("some-catalog", "my_schema", "my_table")), ImmutableList.of()); + List rowFilters = accessControl.getRowFilters( + userGroup3, + new CatalogSchemaTableName("some-catalog", "my_schema", "my_table")); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters( - userGroup3, - new CatalogSchemaTableName("some-catalog", "my_schema", "my_table")), + rowFilters.get(0), new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "country='US'")); } @@ -1410,27 +1412,27 @@ public void testGetColumnMask() SystemAccessControl accessControl = newFileBasedSystemAccessControl("file-based-system-access-table.json"); assertEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( ALICE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked", VARCHAR), - ImmutableList.of()); + Optional.empty()); assertViewExpressionEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked", - VARCHAR), + VARCHAR).orElseThrow(), new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'")); assertViewExpressionEquals( - accessControl.getColumnMasks( + accessControl.getColumnMask( CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked_with_user", - VARCHAR), + VARCHAR).orElseThrow(), new ViewExpression("mask-user", Optional.of("some-catalog"), Optional.of("bobschema"), "'mask-with-user'")); } @@ -1443,19 +1445,21 @@ public void testGetRowFilter() accessControl.getRowFilters(ALICE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")), ImmutableList.of()); + List rowFilters = accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")), + rowFilters.get(0), new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')")); + rowFilters = accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")); + assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( - accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")), + rowFilters.get(0), new ViewExpression("filter-user", Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')")); } - private static void assertViewExpressionEquals(List result, ViewExpression expected) + private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(result.size(), 1); - ViewExpression actual = result.get(0); assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java index cb175ea34c10..62e1ca2fc621 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java @@ -421,6 +421,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return ImmutableList.of(); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java index 435ec58c4c68..0587986e8ce7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java @@ -631,6 +631,12 @@ public List getRowFilters(ConnectorSecurityContext context, Sche return ImmutableList.of(); } + @Override + public Optional getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) + { + return Optional.empty(); + } + @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type) { diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index 90981d6844c5..414df7fbc3c8 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -353,7 +353,7 @@ public void testReferencedTablesAndRoutines() ColumnInfo column = table.getColumns().get(0); assertEquals(column.getColumn(), "linenumber"); - assertTrue(column.getMasks().isEmpty()); + assertTrue(column.getMask().isEmpty()); List routines = event.getMetadata().getRoutines(); assertEquals(tables.size(), 1); @@ -385,7 +385,7 @@ public void testReferencedTablesWithViews() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("nationkey"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -398,7 +398,7 @@ public void testReferencedTablesWithViews() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_column"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test @@ -422,7 +422,7 @@ public void testReferencedTablesWithMaterializedViews() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("nationkey"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -435,7 +435,7 @@ public void testReferencedTablesWithMaterializedViews() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_column"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test @@ -522,7 +522,7 @@ public void testReferencedTablesWithRowFilter() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("name"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -535,7 +535,7 @@ public void testReferencedTablesWithRowFilter() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_varchar"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test @@ -570,7 +570,7 @@ public void testReferencedTablesWithColumnMask() ColumnInfo column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("orderkey"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); table = tables.get(1); assertThat(table.getCatalog()).isEqualTo("mock"); @@ -583,11 +583,11 @@ public void testReferencedTablesWithColumnMask() column = table.getColumns().get(0); assertThat(column.getColumn()).isEqualTo("test_varchar"); - assertThat(column.getMasks()).hasSize(1); + assertThat(column.getMask()).isPresent(); column = table.getColumns().get(1); assertThat(column.getColumn()).isEqualTo("test_bigint"); - assertThat(column.getMasks()).isEmpty(); + assertThat(column.getMask()).isEmpty(); } @Test From fc57354b777145181dcd41b3efe72d8f460d153a Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 11 Jan 2023 18:03:15 -0800 Subject: [PATCH 2/2] Fix aliasing of fields in mask expressions This causes two problems: * Masks can inadvertently refer to columns that appear earlier in the list of columns as the projection is planned. This causes the mask expression to see the masked value of other columns instead of the underlying value * A possible bug in other optimizers causes mask expressions to be lost when the result of a such an expression is of type ROW and there's a dereference of a field downstream --- .../io/trino/sql/planner/RelationPlanner.java | 32 +++++++------ .../io/trino/sql/query/TestColumnMask.java | 45 +++++++++++++++++-- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index b1f94ad5d718..d3e7c56fb3dd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -100,7 +100,6 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -305,28 +304,33 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext) .withScope(analysis.getAccessControlScope(table), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope + Assignments.Builder assignments = Assignments.builder(); + assignments.putIdentities(planBuilder.getRoot().getOutputSymbols()); + + List fieldMappings = new ArrayList<>(); for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { Field field = plan.getDescriptor().getFieldByIndex(i); Expression mask = columnMasks.get(field.getName().orElseThrow()); + Symbol symbol = plan.getFieldMappings().get(i); + Expression projection = symbol.toSymbolReference(); if (mask != null) { planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); - - Map assignments = new LinkedHashMap<>(); - for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { - assignments.put(symbol, symbol.toSymbolReference()); - } - assignments.put(plan.getFieldMappings().get(i), coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask))); - - planBuilder = planBuilder - .withNewRoot(new ProjectNode( - idAllocator.getNextId(), - planBuilder.getRoot(), - Assignments.copyOf(assignments))); + symbol = symbolAllocator.newSymbol(symbol); + projection = coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask)); } + + assignments.put(symbol, projection); + fieldMappings.add(symbol); } - return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext); + planBuilder = planBuilder + .withNewRoot(new ProjectNode( + idAllocator.getNextId(), + planBuilder.getRoot(), + assignments.build())); + + return new RelationPlan(planBuilder.getRoot(), plan.getScope(), fieldMappings, outerContext); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 7da2e068774c..3b268a6a4462 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -26,6 +26,7 @@ import io.trino.spi.security.Identity; import io.trino.spi.security.ViewExpression; import io.trino.spi.type.BigintType; +import io.trino.spi.type.RowType; import io.trino.spi.type.VarcharType; import io.trino.testing.LocalQueryRunner; import io.trino.testing.TestingAccessControlManager; @@ -41,6 +42,7 @@ import static io.trino.connector.MockConnectorEntities.TPCH_NATION_WITH_HIDDEN_COLUMN; import static io.trino.connector.MockConnectorEntities.TPCH_WITH_HIDDEN_COLUMN_DATA; import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; import static io.trino.testing.TestingAccessControlManager.privilege; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; @@ -86,6 +88,26 @@ public void init() Optional.of(VIEW_OWNER), false); + ConnectorViewDefinition viewWithNested = new ConnectorViewDefinition( + """ + SELECT * FROM ( + VALUES + ROW(ROW(1,2), 0), + ROW(ROW(3,4), 1) + ) t(nested, id) + """, + Optional.empty(), + Optional.empty(), + ImmutableList.of( + new ConnectorViewDefinition.ViewColumn("nested", RowType.from(ImmutableList.of( + RowType.field(INTEGER), + RowType.field(INTEGER))).getTypeId(), + Optional.empty()), + new ConnectorViewDefinition.ViewColumn("id", INTEGER.getTypeId(), Optional.empty())), + Optional.empty(), + Optional.of(VIEW_OWNER), + false); + ConnectorMaterializedViewDefinition materializedView = new ConnectorMaterializedViewDefinition( "SELECT * FROM local.tiny.nation", Optional.empty(), @@ -142,7 +164,8 @@ public void init() throw new UnsupportedOperationException(); }) .withGetViews((s, prefix) -> ImmutableMap.of( - new SchemaTableName("default", "nation_view"), view)) + new SchemaTableName("default", "nation_view"), view, + new SchemaTableName("default", "view_with_nested"), viewWithNested)) .withGetMaterializedViews((s, prefix) -> ImmutableMap.of( new SchemaTableName("default", "nation_materialized_view"), materializedView, new SchemaTableName("default", "nation_fresh_materialized_view"), freshMaterializedView, @@ -309,7 +332,7 @@ public void testMaterializedView() .setIdentity(Identity.forUser(USER).build()) .build(), "SELECT name FROM mock.default.materialized_view_with_casts WHERE nationkey = 1")) - .matches("VALUES CAST('RA' AS VARCHAR(2))"); + .matches("VALUES 'RA'"); } @Test @@ -812,7 +835,7 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), @@ -827,6 +850,20 @@ public void testMultipleMasksUsingOtherMaskedColumns() new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) - .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('***#000000951' as varchar(15)))"); + .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('###' as varchar(15)))"); + } + + @Test + public void testColumnAliasing() + { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(MOCK_CATALOG, "default", "view_with_nested"), + "nested", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(id = 0, nested)")); + + assertThat(assertions.query("SELECT nested[1] FROM mock.default.view_with_nested")) + .matches("VALUES 1, NULL"); } }