diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index b373ab0e5319..e38d1ced99c6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -2279,7 +2279,7 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Relat private void analyzeCheckConstraints(Table table, QualifiedObjectName name, Scope accessControlScope, List constraints) { for (String constraint : constraints) { - ViewExpression expression = new ViewExpression(session.getIdentity().getUser(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), constraint); + ViewExpression expression = new ViewExpression(Optional.empty(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), constraint); analyzeCheckConstraint(table, name, accessControlScope, expression); } } @@ -4663,9 +4663,11 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje ExpressionAnalysis expressionAnalysis; try { - Identity filterIdentity = Identity.forUser(filter.getIdentity()) - .withGroups(groupProvider.getGroups(filter.getIdentity())) - .build(); + Identity filterIdentity = filter.getSecurityIdentity() + .map(filterUser -> Identity.forUser(filterUser) + .withGroups(groupProvider.getGroups(filterUser)) + .build()) + .orElseGet(session::getIdentity); expressionAnalysis = ExpressionAnalyzer.analyzeExpression( createViewSession(filter.getCatalog(), filter.getSchema(), filterIdentity, session.getPath()), // TODO: path should be included in row filter plannerContext, @@ -4714,11 +4716,13 @@ private void analyzeCheckConstraint(Table table, QualifiedObjectName name, Scope ExpressionAnalysis expressionAnalysis; try { - Identity filterIdentity = Identity.forUser(constraint.getIdentity()) - .withGroups(groupProvider.getGroups(constraint.getIdentity())) - .build(); + Identity constraintIdentity = constraint.getSecurityIdentity() + .map(user -> Identity.forUser(user) + .withGroups(groupProvider.getGroups(user)) + .build()) + .orElseGet(session::getIdentity); expressionAnalysis = ExpressionAnalyzer.analyzeExpression( - createViewSession(constraint.getCatalog(), constraint.getSchema(), filterIdentity, session.getPath()), + createViewSession(constraint.getCatalog(), constraint.getSchema(), constraintIdentity, session.getPath()), plannerContext, statementAnalyzerFactory, accessControl, @@ -4777,9 +4781,11 @@ private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObj verifyNoAggregateWindowOrGroupingFunctions(session, metadata, expression, format("Column mask for '%s.%s'", table.getName(), column)); try { - Identity maskIdentity = Identity.forUser(mask.getIdentity()) - .withGroups(groupProvider.getGroups(mask.getIdentity())) - .build(); + Identity maskIdentity = mask.getSecurityIdentity() + .map(maskUser -> Identity.forUser(maskUser) + .withGroups(groupProvider.getGroups(maskUser)) + .build()) + .orElseGet(session::getIdentity); expressionAnalysis = ExpressionAnalyzer.analyzeExpression( createViewSession(mask.getCatalog(), mask.getSchema(), maskIdentity, session.getPath()), // TODO: path should be included in row filter plannerContext, diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index 36d18a9c4fda..edcccc3ad1c5 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -839,70 +839,22 @@ public String toString() } } - private static class RowFilterKey + private record RowFilterKey(String identity, QualifiedObjectName table) { - private final String identity; - private final QualifiedObjectName table; - - public RowFilterKey(String identity, QualifiedObjectName table) - { - this.identity = requireNonNull(identity, "identity is null"); - this.table = requireNonNull(table, "table is null"); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - RowFilterKey that = (RowFilterKey) o; - return identity.equals(that.identity) && - table.equals(that.table); - } - - @Override - public int hashCode() + private RowFilterKey { - return Objects.hash(identity, table); + requireNonNull(identity, "identity is null"); + requireNonNull(table, "table is null"); } } - private static class ColumnMaskKey + private record ColumnMaskKey(String identity, QualifiedObjectName table, String column) { - private final String identity; - private final QualifiedObjectName table; - private final String column; - - public ColumnMaskKey(String identity, QualifiedObjectName table, String column) - { - this.identity = identity; - this.table = table; - this.column = column; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ColumnMaskKey that = (ColumnMaskKey) o; - return identity.equals(that.identity) && - table.equals(that.table) && - column.equals(that.column); - } - - @Override - public int hashCode() + private ColumnMaskKey { - return Objects.hash(identity, table, column); + requireNonNull(identity, "identity is null"); + requireNonNull(table, "table is null"); + requireNonNull(column, "column is null"); } } } diff --git a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java index 07ce6c40feca..af1c2534a755 100644 --- a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java +++ b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java @@ -231,7 +231,7 @@ public SystemAccessControl create(Map config) @Override public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String column, Type type) { - return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask")); + return ImmutableList.of(new ViewExpression(Optional.of("user"), Optional.empty(), Optional.empty(), "system mask")); } @Override @@ -249,7 +249,7 @@ public void checkCanSetSystemSessionProperty(SystemSecurityContext context, Stri @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type) { - return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask")); + return ImmutableList.of(new ViewExpression(Optional.of("user"), Optional.empty(), Optional.empty(), "connector mask")); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java index 19fece81b2d5..81bd4d050432 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java @@ -236,7 +236,7 @@ public void testMaterializedViewWithCasts() new QualifiedObjectName(TEST_CATALOG_NAME, SCHEMA, "materialized_view_with_casts"), "a", "user", - new ViewExpression("user", Optional.empty(), Optional.empty(), "a + 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "a + 1")); assertPlan("SELECT * FROM materialized_view_with_casts", anyTree( project( diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 3b268a6a4462..9c176068c678 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -194,7 +194,7 @@ public void testSimpleMask() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "-custkey")); assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '-370'"); accessControl.reset(); @@ -202,7 +202,7 @@ public void testSimpleMask() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "NULL")); assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES CAST(NULL AS BIGINT)"); } @@ -214,7 +214,7 @@ public void testConditionalMask() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "IF (orderkey < 2, null, -custkey)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "IF (orderkey < 2, null, -custkey)")); assertThat(assertions.query("SELECT custkey FROM orders LIMIT 2")) .matches("VALUES (NULL), CAST('-781' AS BIGINT)"); } @@ -227,13 +227,13 @@ public void testMultipleMasksOnDifferentColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "-custkey")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "'X'")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "'X'")); assertThat(assertions.query("SELECT custkey, orderstatus FROM orders WHERE orderkey = 1")) .matches("VALUES (BIGINT '-370', 'X')"); @@ -247,13 +247,13 @@ public void testReferenceInUsingClause() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "lineitem"), "orderkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); assertThat(assertions.query("SELECT count(*) FROM orders JOIN lineitem USING (orderkey)")).matches("VALUES BIGINT '6'"); } @@ -266,7 +266,7 @@ public void testCoercibleType() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "CAST(clerk AS VARCHAR(5))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "CAST(clerk AS VARCHAR(5))")); assertThat(assertions.query("SELECT clerk FROM orders WHERE orderkey = 1")).matches("VALUES CAST('Clerk' AS VARCHAR(15))"); } @@ -279,7 +279,7 @@ public void testSubquery() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation)")); assertThat(assertions.query("SELECT clerk FROM orders WHERE orderkey = 1")).matches("VALUES CAST('VIETNAM' AS VARCHAR(15))"); // correlated @@ -288,7 +288,7 @@ public void testSubquery() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation WHERE nationkey = orderkey)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation WHERE nationkey = orderkey)")); assertThat(assertions.query("SELECT clerk FROM orders WHERE orderkey = 1")).matches("VALUES CAST('ARGENTINA' AS VARCHAR(15))"); } @@ -301,17 +301,17 @@ public void testMaterializedView() new QualifiedObjectName(MOCK_CATALOG, "default", "nation_fresh_materialized_view"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "reverse(name)")); accessControl.columnMask( new QualifiedObjectName(MOCK_CATALOG, "default", "nation_materialized_view"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "reverse(name)")); accessControl.columnMask( new QualifiedObjectName(MOCK_CATALOG, "default", "materialized_view_with_casts"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -344,7 +344,7 @@ public void testView() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(VIEW_OWNER), Optional.empty(), Optional.empty(), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -359,7 +359,7 @@ public void testView() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); + new ViewExpression(Optional.of(VIEW_OWNER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -374,7 +374,7 @@ public void testView() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -389,7 +389,7 @@ public void testView() new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view"), "name", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); assertThat(assertions.query("SELECT name FROM mock.default.nation_view WHERE nationkey = 1")).matches("VALUES CAST('ANITNEGRA' AS VARCHAR(25))"); } @@ -401,7 +401,7 @@ public void testTableReferenceInWithClause() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "-custkey")); assertThat(assertions.query("WITH t AS (SELECT custkey FROM orders WHERE orderkey = 1) SELECT * FROM t")).matches("VALUES BIGINT '-370'"); } @@ -413,7 +413,7 @@ public void testOtherSchema() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer)")); // count is 15000 only when evaluating against sf1 + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer)")); // count is 15000 only when evaluating against sf1 assertThat(assertions.query("SELECT max(orderkey) FROM orders")).matches("VALUES BIGINT '150000'"); } @@ -425,13 +425,13 @@ public void testDifferentIdentity() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "100")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "100")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT sum(orderkey) FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT sum(orderkey) FROM orders)")); assertThat(assertions.query("SELECT max(orderkey) FROM orders")).matches("VALUES BIGINT '1500000'"); } @@ -444,7 +444,7 @@ public void testRecursion() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessageMatching(".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); @@ -455,7 +455,7 @@ public void testRecursion() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM local.tiny.orders)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM local.tiny.orders)")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessageMatching(".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); @@ -466,13 +466,13 @@ public void testRecursion() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessageMatching(".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); @@ -486,7 +486,7 @@ public void testLimitedScope() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "customer"), "custkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey")); assertThatThrownBy(() -> assertions.query( "SELECT (SELECT min(custkey) FROM customer WHERE customer.custkey = orders.custkey) FROM orders")) .hasMessage("line 1:34: Invalid column mask for 'local.tiny.customer.custkey': Column 'orderkey' cannot be resolved"); @@ -500,7 +500,7 @@ public void testSqlInjection() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT name FROM region WHERE regionkey = 0)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT name FROM region WHERE regionkey = 0)")); assertThat(assertions.query( "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'))" + "SELECT name FROM nation ORDER BY name LIMIT 1")) @@ -516,7 +516,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Invalid column mask for 'local.tiny.orders.orderkey': mismatched input '$'. Expecting: "); @@ -527,7 +527,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Invalid column mask for 'local.tiny.orders.orderkey': Column 'unknown_column' cannot be resolved"); @@ -538,7 +538,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "'foo'")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "'foo'")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Expected column mask for 'local.tiny.orders.orderkey' to be of type bigint, but was varchar(3)"); @@ -549,7 +549,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:10: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [count(*)]"); @@ -560,7 +560,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [row_number() OVER ()]"); @@ -571,7 +571,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:20: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]"); @@ -585,7 +585,7 @@ public void testShowStats() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "7")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "7")); assertThat(assertions.query("SHOW STATS FOR (SELECT * FROM orders)")) .containsAll(""" @@ -616,7 +616,7 @@ public void testJoin() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey + 1")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey + 1")); assertThat(assertions.query("SELECT count(*) FROM orders JOIN orders USING (orderkey)")).matches("VALUES BIGINT '15000'"); } @@ -629,7 +629,7 @@ public void testColumnMaskingUsingRestrictedColumn() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "custkey")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("Access Denied: Cannot select from columns [orderkey, custkey] in table or view local.tiny.orders"); } @@ -642,7 +642,7 @@ public void testInsertWithColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThatThrownBy(() -> assertions.query("INSERT INTO orders SELECT * FROM orders")) .hasMessage("Insert into table with column masks is not supported"); } @@ -655,7 +655,7 @@ public void testDeleteWithColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThatThrownBy(() -> assertions.query("DELETE FROM orders")) .hasMessage("line 1:1: Delete from table with column mask"); } @@ -668,7 +668,7 @@ public void testUpdateWithColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThatThrownBy(() -> assertions.query("UPDATE orders SET clerk = 'X'")) .hasMessage("line 1:1: Updating a table with column masks is not supported"); assertThatThrownBy(() -> assertions.query("UPDATE orders SET orderkey = -orderkey")) @@ -687,7 +687,7 @@ public void testNotReferencedAndDeniedColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThat(assertions.query("SELECT orderkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '1'"); // mask on long column @@ -697,7 +697,7 @@ public void testNotReferencedAndDeniedColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "totalprice", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "totalprice")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "totalprice")); assertThat(assertions.query("SELECT orderkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '1'"); // mask on not used varchar column with subquery masking @@ -708,7 +708,7 @@ public void testNotReferencedAndDeniedColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "(SELECT orderstatus FROM local.tiny.orders)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "(SELECT orderstatus FROM local.tiny.orders)")); assertThat(assertions.query("SELECT orderkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '1'"); } @@ -720,7 +720,7 @@ public void testColumnMaskWithHiddenColumns() new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation_with_hidden_column"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "'POLAND'")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "'POLAND'")); assertions.query("SELECT * FROM mock.tiny.nation_with_hidden_column WHERE nationkey = 1") .assertThat() @@ -754,19 +754,19 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(comment,'(password: [^ ]+)','password: ****') as varchar(79))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(comment,'(password: [^ ]+)','password: ****') as varchar(79))")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '***', clerk)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '***', clerk)")); assertThat(assertions.query(query)).matches(expected); @@ -777,13 +777,13 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); assertThat(assertions.query(query)).matches(expected); @@ -794,19 +794,19 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); assertThat(assertions.query(query)).matches(expected); @@ -817,13 +817,13 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) .matches("VALUES (CAST('***' as varchar(79)), 'O', CAST('***#000000951' as varchar(15)))"); @@ -835,19 +835,19 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '*', orderstatus)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '*', orderstatus)")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('###' as varchar(15)))"); @@ -861,7 +861,7 @@ public void testColumnAliasing() new QualifiedObjectName(MOCK_CATALOG, "default", "view_with_nested"), "nested", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(id = 0, nested)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(id = 0, nested)")); assertThat(assertions.query("SELECT nested[1] FROM mock.default.view_with_nested")) .matches("VALUES 1, NULL"); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java b/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java index 89eeaceef02b..a6df900756e0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java @@ -100,9 +100,7 @@ public void testDescribeBaseline() assertThat(assertions.query("DESCRIBE nation")) .matches(materializedRows -> materializedRows .getMaterializedRows().stream() - .filter(materializedRow -> materializedRow.getField(0).equals("comment")) - .findFirst() - .isPresent()); + .anyMatch(materializedRow -> materializedRow.getField(0).equals("comment"))); } @Test @@ -112,9 +110,7 @@ public void testDescribe() assertThat(assertions.query("DESCRIBE nation")) .matches(materializedRows -> materializedRows .getMaterializedRows().stream() - .filter(materializedRow -> materializedRow.getField(0).equals("comment")) - .findFirst() - .isEmpty()); + .noneMatch(materializedRow -> materializedRow.getField(0).equals("comment"))); } @Test @@ -123,9 +119,7 @@ public void testShowColumnsBaseline() assertThat(assertions.query("SHOW COLUMNS FROM nation")) .matches(materializedRows -> materializedRows .getMaterializedRows().stream() - .filter(materializedRow -> materializedRow.getField(0).equals("comment")) - .findFirst() - .isPresent()); + .anyMatch(materializedRow -> materializedRow.getField(0).equals("comment"))); } @Test @@ -135,9 +129,7 @@ public void testShowColumns() assertThat(assertions.query("SHOW COLUMNS FROM nation")) .matches(materializedRows -> materializedRows .getMaterializedRows().stream() - .filter(materializedRow -> materializedRow.getField(0).equals("comment")) - .findFirst() - .isEmpty()); + .noneMatch(materializedRow -> materializedRow.getField(0).equals("comment"))); } /** @@ -157,59 +149,74 @@ public void testFilterExplicitSelect() } @Test - public void testRowFilterOnNotAccessibleColumn() + public void testRowFilterWithAccessToInaccessibleColumn() { accessControl.rowFilter(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), USER, - new ViewExpression(ADMIN, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); + new ViewExpression(Optional.of(ADMIN), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); accessControl.deny(privilege(USER, "nation.comment", SELECT_COLUMN)); assertThat(assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .matches("VALUES (BIGINT '6', CAST('FRANCE' AS VARCHAR(25)), BIGINT '3')"); } @Test - public void testRowFilterOnNotAccessibleColumnKO() + public void testRowFilterWithoutAccessToInaccessibleColumn() { accessControl.rowFilter(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), USER, - new ViewExpression(USER, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); + new ViewExpression(Optional.of(USER), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); accessControl.deny(privilege(USER, "nation.comment", SELECT_COLUMN)); assertThatThrownBy(() -> assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); } + @Test + public void testRowFilterAsSessionUserOnInaccessibleColumn() + { + accessControl.deny(privilege(USER, "nation.comment", SELECT_COLUMN)); + QualifiedObjectName table = new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"); + ViewExpression filter = new ViewExpression(Optional.empty(), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null"); + accessControl.rowFilter(table, ADMIN, filter); + accessControl.rowFilter(table, USER, filter); + + assertThatThrownBy(() -> assertions.query(user(USER), "SELECT * FROM nation WHERE name = 'FRANCE'")) + .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); + assertThat(assertions.query(user(ADMIN), "SELECT * FROM nation WHERE name = 'FRANCE'")) + .matches("VALUES (BIGINT '6', CAST('FRANCE' AS VARCHAR(25)), BIGINT '3', CAST('refully final requests. regular, ironi' AS VARCHAR(152)))"); + } + @Test public void testMaskingOnAccessibleColumn() { accessControl.columnMask(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), "nationkey", USER, - new ViewExpression(ADMIN, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "-nationkey")); + new ViewExpression(Optional.of(ADMIN), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "-nationkey")); assertThat(assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .matches("VALUES (BIGINT '-6',CAST('FRANCE' AS VARCHAR(25)), BIGINT '3', CAST('refully final requests. regular, ironi' AS VARCHAR(152)))"); } @Test - public void testMaskingWithCaseOnNotAccessibleColumnKO() + public void testMaskingWithoutAccessToInaccessibleColumn() { accessControl.deny(privilege(USER, "nation.nationkey", SELECT_COLUMN)); accessControl.columnMask(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), "comment", USER, - new ViewExpression(USER, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); + new ViewExpression(Optional.of(USER), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); assertThatThrownBy(() -> assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); } @Test - public void testMaskingWithCaseOnNotAccessibleColumn() + public void testMaskingWithAccessToInaccessibleColumn() { accessControl.deny(privilege(USER, "nation.nationkey", SELECT_COLUMN)); accessControl.columnMask(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), "comment", USER, - new ViewExpression(ADMIN, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); + new ViewExpression(Optional.of(ADMIN), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); assertThat(assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .matches("VALUES (CAST('FRANCE' AS VARCHAR(25)), BIGINT '3', CAST('masked-comment' AS VARCHAR(152)))"); @@ -218,6 +225,21 @@ public void testMaskingWithCaseOnNotAccessibleColumn() .matches("VALUES (CAST('CANADA' AS VARCHAR(25)), BIGINT '1', CAST('eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold' AS VARCHAR(152)))"); } + @Test + public void testMaskingAsSessionUserWithCaseOnInaccessibleColumn() + { + accessControl.deny(privilege(USER, "nation.nationkey", SELECT_COLUMN)); + QualifiedObjectName table = new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"); + ViewExpression mask = new ViewExpression(Optional.empty(), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 3 THEN 'masked-comment' ELSE comment END"); + accessControl.columnMask(table, "comment", ADMIN, mask); + accessControl.columnMask(table, "comment", USER, mask); + + assertThatThrownBy(() -> assertions.query(user(USER), "SELECT * FROM nation WHERE name = 'FRANCE'")) + .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); + assertThat(assertions.query(user(ADMIN), "SELECT * FROM nation WHERE name = 'CANADA'")) + .matches("VALUES (BIGINT '3', CAST('CANADA' AS VARCHAR(25)), BIGINT '1', CAST('masked-comment' AS VARCHAR(152)))"); + } + @Test public void testPredicateOnInaccessibleColumn() { @@ -265,4 +287,11 @@ public void testFunctionOnInaccessibleColumn() assertThatThrownBy(() -> assertions.query("SELECT * FROM (SELECT concat(name,'-test') FROM nation WHERE name = 'FRANCE')")) .hasMessage("Access Denied: Cannot select from columns [name] in table or view test-catalog.tiny.nation"); } + + private Session user(String user) + { + return Session.builder(assertions.getDefaultSession()) + .setIdentity(Identity.ofUser(user)) + .build(); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java b/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java index aa8b388bf526..ed9053b140b0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java @@ -156,14 +156,14 @@ public void testSimpleFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey < 10")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '7'"); accessControl.reset(); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "NULL")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '0'"); } @@ -174,12 +174,12 @@ public void testMultipleFilters() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey < 10")); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey > 5")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey > 5")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '2'"); } @@ -191,7 +191,7 @@ public void testCorrelatedSubquery() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '7'"); } @@ -203,7 +203,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.empty(), Optional.empty(), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey = 1")); assertThat(assertions.query( Session.builder(SESSION) @@ -217,7 +217,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); assertThat(assertions.query( Session.builder(SESSION) @@ -231,7 +231,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); Session session = Session.builder(SESSION) .setIdentity(Identity.forUser(RUN_AS_USER).build()) @@ -244,7 +244,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); assertThat(assertions.query("SELECT name FROM mock.default.nation_view")).matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))"); } @@ -255,7 +255,7 @@ public void testTableReferenceInWithClause() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey = 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey = 1")); assertThat(assertions.query("WITH t AS (SELECT count(*) FROM orders) SELECT * FROM t")).matches("VALUES BIGINT '1'"); } @@ -266,7 +266,7 @@ public void testOtherSchema() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000")); // Filter is TRUE only if evaluating against sf1.customer + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000")); // Filter is TRUE only if evaluating against sf1.customer assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '15000'"); } @@ -277,12 +277,12 @@ public void testDifferentIdentity() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '1'"); } @@ -294,7 +294,7 @@ public void testRecursion() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); @@ -304,7 +304,7 @@ public void testRecursion() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); @@ -313,12 +313,12 @@ public void testRecursion() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); @@ -331,7 +331,7 @@ public void testLimitedScope() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "customer"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); assertThatThrownBy(() -> assertions.query( "SELECT (SELECT min(name) FROM customer WHERE customer.custkey = orders.custkey) FROM orders")) .hasMessage("line 1:31: Invalid row filter for 'local.tiny.customer': Column 'orderkey' cannot be resolved"); @@ -344,7 +344,7 @@ public void testSqlInjection() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')")); assertThat(assertions.query( "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'), (1, 'ASIA'), (2, 'ASIA'), (3, 'ASIA'), (4, 'ASIA'))" + "SELECT name FROM nation ORDER BY name LIMIT 1")) @@ -359,7 +359,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Invalid row filter for 'local.tiny.orders': mismatched input '$'. Expecting: "); @@ -369,7 +369,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Invalid row filter for 'local.tiny.orders': Column 'unknown_column' cannot be resolved"); @@ -379,7 +379,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "1")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Expected row filter for 'local.tiny.orders' to be of type BOOLEAN, but was integer"); @@ -389,7 +389,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:10: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [count(*)]"); @@ -399,7 +399,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [row_number() OVER ()]"); @@ -409,7 +409,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:20: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]"); @@ -422,7 +422,7 @@ public void testShowStats() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 0")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 0")); assertThat(assertions.query("SHOW STATS FOR (SELECT * FROM tiny.orders)")) .containsAll( @@ -442,7 +442,7 @@ public void testDelete() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertions.query("DELETE FROM mock.tiny.nation WHERE nationkey < 3") @@ -474,7 +474,7 @@ public void testMergeDelete() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertThatThrownBy(() -> assertions.query(""" @@ -507,7 +507,7 @@ public void testUpdate() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2 WHERE nationkey < 3")) @@ -547,7 +547,7 @@ public void testMergeUpdate() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertThatThrownBy(() -> assertions.query(""" @@ -604,7 +604,7 @@ public void testInsert() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey > 100")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey > 100")); // Within allowed row filter assertions.query("INSERT INTO mock.tiny.nation VALUES (101, 'POLAND', 0, 'No comment')") @@ -635,7 +635,7 @@ public void testMergeInsert() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey > 100")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey > 100")); // Within allowed row filter assertThatThrownBy(() -> assertions.query(""" @@ -670,7 +670,7 @@ public void testRowFilterWithHiddenColumns() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation_with_hidden_column"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 1")); assertions.query("SELECT * FROM mock.tiny.nation_with_hidden_column") .assertThat() @@ -703,7 +703,7 @@ public void testRowFilterOnHiddenColumn() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation_with_hidden_column"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "\"$hidden\" < 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "\"$hidden\" < 1")); assertions.query("SELECT count(*) FROM mock.tiny.nation_with_hidden_column") .assertThat() @@ -730,7 +730,7 @@ public void testRowFilterOnOptionalColumn() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG_MISSING_COLUMNS, "tiny", "nation_with_optional_column"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "length(optional) > 2")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "length(optional) > 2")); assertions.query("INSERT INTO mockmissingcolumns.tiny.nation_with_optional_column(nationkey, name, regionkey, comment, optional) VALUES (0, 'POLAND', 0, 'No comment', 'some string')") .assertThat() diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java b/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java index 7348717c3b3b..7e280aa140e5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java @@ -19,12 +19,18 @@ public class ViewExpression { - private final String identity; + private final Optional identity; private final Optional catalog; private final Optional schema; private final String expression; + @Deprecated public ViewExpression(String identity, Optional catalog, Optional schema, String expression) + { + this(Optional.of(identity), catalog, schema, expression); + } + + public ViewExpression(Optional identity, Optional catalog, Optional schema, String expression) { this.identity = requireNonNull(identity, "identity is null"); this.catalog = requireNonNull(catalog, "catalog is null"); @@ -36,7 +42,17 @@ public ViewExpression(String identity, Optional catalog, Optional getSecurityIdentity() { return identity; } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java index 07ddf5b19040..71dc5e284a28 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java @@ -79,14 +79,14 @@ public boolean canSelectColumns(Set columnNames) return tableAccessControlRule.canSelectColumns(columnNames); } - public Optional getColumnMask(String user, String catalog, String schema, String column) + public Optional getColumnMask(String catalog, String schema, String column) { - return tableAccessControlRule.getColumnMask(user, catalog, schema, column); + return tableAccessControlRule.getColumnMask(catalog, schema, column); } - public Optional getFilter(String user, String catalog, String schema) + public Optional getFilter(String catalog, String schema) { - return tableAccessControlRule.getFilter(user, catalog, schema); + return tableAccessControlRule.getFilter(catalog, schema); } Optional toAnyCatalogPermissionsRule() diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java index c4ce581c83cb..590744dd89cc 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java @@ -647,7 +647,7 @@ public List getRowFilters(ConnectorSecurityContext context, Sche ConnectorIdentity identity = context.getIdentity(); return tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName)) - .map(rule -> rule.getFilter(identity.getUser(), catalogName, tableName.getSchemaName())) + .map(rule -> rule.getFilter(catalogName, tableName.getSchemaName())) // we return the first one we find .findFirst() .stream() @@ -665,7 +665,7 @@ public Optional getColumnMask(ConnectorSecurityContext context, ConnectorIdentity identity = context.getIdentity(); List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName)) - .map(rule -> rule.getColumnMask(identity.getUser(), catalogName, tableName.getSchemaName(), columnName)) + .map(rule -> rule.getColumnMask(catalogName, tableName.getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java index e1cfac1da06c..20c5d5d4a9bf 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java @@ -965,7 +965,7 @@ public List getRowFilters(SystemSecurityContext context, Catalog Identity identity = context.getIdentity(); return tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table)) - .map(rule -> rule.getFilter(identity.getUser(), table.getCatalogName(), tableName.getSchemaName())) + .map(rule -> rule.getFilter(table.getCatalogName(), tableName.getSchemaName())) // we return the first one we find .findFirst() .stream() @@ -984,7 +984,7 @@ public Optional getColumnMask(SystemSecurityContext context, Cat Identity identity = context.getIdentity(); List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table)) - .map(rule -> rule.getColumnMask(identity.getUser(), table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName)) + .map(rule -> rule.getColumnMask(table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java index e2d2e3f19ce5..13e6b5536e58 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java @@ -102,20 +102,20 @@ public boolean canSelectColumns(Set columnNames) return (privileges.contains(SELECT) || privileges.contains(GRANT_SELECT)) && restrictedColumns.stream().noneMatch(columnNames::contains); } - public Optional getColumnMask(String user, String catalog, String schema, String column) + public Optional getColumnMask(String catalog, String schema, String column) { return Optional.ofNullable(columnConstraints.get(column)).flatMap(constraint -> constraint.getMask().map(mask -> new ViewExpression( - constraint.getMaskEnvironment().flatMap(ExpressionEnvironment::getUser).orElse(user), + constraint.getMaskEnvironment().flatMap(ExpressionEnvironment::getUser), Optional.of(catalog), Optional.of(schema), mask))); } - public Optional getFilter(String user, String catalog, String schema) + public Optional getFilter(String catalog, String schema) { return filter.map(filter -> new ViewExpression( - filterEnvironment.flatMap(ExpressionEnvironment::getUser).orElse(user), + filterEnvironment.flatMap(ExpressionEnvironment::getUser), Optional.of(catalog), Optional.of(schema), filter)); diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java index d155baaced98..9b5a568beb5b 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java @@ -419,7 +419,7 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanSelectFromColumns(userGroup2, myTable, ImmutableSet.of()); assertViewExpressionEquals( accessControl.getColumnMask(userGroup2, myTable, "col_a", VARCHAR).orElseThrow(), - new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); + new ViewExpression(Optional.empty(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); assertEquals( accessControl.getRowFilters(userGroup2, myTable), ImmutableList.of()); @@ -443,18 +443,18 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanSelectFromColumns(userGroup3, myTable, ImmutableSet.of()); assertViewExpressionEquals( accessControl.getColumnMask(userGroup3, myTable, "col_a", VARCHAR).orElseThrow(), - new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); + new ViewExpression(Optional.empty(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); List rowFilters = accessControl.getRowFilters(userGroup3, myTable); assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "country='US'")); + new ViewExpression(Optional.empty(), Optional.of("test_catalog"), Optional.of("my_schema"), "country='US'")); } private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); + assertEquals(actual.getSecurityIdentity(), expected.getSecurityIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); assertEquals(actual.getExpression(), expected.getExpression(), "Expression"); diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java index 3b687a4db496..2fa79c6a4bf7 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java @@ -802,7 +802,7 @@ public void testTableRulesForMixedGroupUsers() new CatalogSchemaTableName("some-catalog", "my_schema", "my_table"), "col_a", VARCHAR).orElseThrow(), - new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "'mask_a'")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("my_schema"), "'mask_a'")); SystemSecurityContext userGroup1Group3 = new SystemSecurityContext(Identity.forUser("user_1_3") .withGroups(ImmutableSet.of("group1", "group3")).build(), Optional.empty()); @@ -821,7 +821,7 @@ public void testTableRulesForMixedGroupUsers() assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "country='US'")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("my_schema"), "country='US'")); } @Test @@ -1425,7 +1425,7 @@ public void testGetColumnMask() new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked", VARCHAR).orElseThrow(), - new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'")); assertViewExpressionEquals( accessControl.getColumnMask( @@ -1433,7 +1433,7 @@ public void testGetColumnMask() new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked_with_user", VARCHAR).orElseThrow(), - new ViewExpression("mask-user", Optional.of("some-catalog"), Optional.of("bobschema"), "'mask-with-user'")); + new ViewExpression(Optional.of("mask-user"), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask-with-user'")); } @Test @@ -1449,18 +1449,18 @@ public void testGetRowFilter() assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')")); rowFilters = accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")); assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression("filter-user", Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')")); + new ViewExpression(Optional.of("filter-user"), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')")); } private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); + assertEquals(actual.getSecurityIdentity(), expected.getSecurityIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); assertEquals(actual.getExpression(), expected.getExpression(), "Expression"); diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java index 3345d6fc63ad..9dde277210bb 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java @@ -122,6 +122,7 @@ protected QueryRunner createQueryRunner() throws Exception { Session session = testSessionBuilder() + .setSource("test") .setCatalog("blackhole") .setSchema("default") .build(); @@ -969,7 +970,7 @@ public void testAccessControlWithGroupsAndColumnMask() new QualifiedObjectName("blackhole", "default", "orders"), "comment", getSession().getUser(), - new ViewExpression(getSession().getUser(), Optional.empty(), Optional.empty(), "substr(comment,1,3)")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "substr(comment,1,3)")); assertAccessAllowed("SELECT comment FROM orders"); } @@ -983,11 +984,54 @@ public void testAccessControlWithGroupsAndRowFilter() accessControlManager.rowFilter( new QualifiedObjectName("blackhole", "default", "nation"), getSession().getUser(), - new ViewExpression(getSession().getUser(), Optional.empty(), Optional.empty(), "nationkey % 2 = 0")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey % 2 = 0")); assertAccessAllowed("SELECT nationkey FROM nation"); } + @Test + public void testAccessControlWithRolesAndColumnMask() + { + String role = "role"; + String user = "user"; + Session session = Session.builder(getSession()) + .setIdentity(Identity.forUser(user) + .withEnabledRoles(ImmutableSet.of(role)) + .build()) + .build(); + systemSecurityMetadata.grantRoles(getSession(), Set.of(role), Set.of(new TrinoPrincipal(USER, user)), false, Optional.empty()); + TestingAccessControlManager accessControlManager = getQueryRunner().getAccessControl(); + accessControlManager.denyIdentityTable((identity, table) -> (identity.getEnabledRoles().contains(role) && "orders".equals(table))); + accessControlManager.columnMask( + new QualifiedObjectName("blackhole", "default", "orders"), + "comment", + getSession().getUser(), + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "substr(comment,1,3)")); + + assertAccessAllowed(session, "SELECT comment FROM orders"); + } + + @Test + public void testAccessControlWithRolesAndRowFilter() + { + String role = "role"; + String user = "user"; + Session session = Session.builder(getSession()) + .setIdentity(Identity.forUser(user) + .withEnabledRoles(ImmutableSet.of(role)) + .build()) + .build(); + systemSecurityMetadata.grantRoles(getSession(), Set.of(role), Set.of(new TrinoPrincipal(USER, user)), false, Optional.empty()); + TestingAccessControlManager accessControlManager = getQueryRunner().getAccessControl(); + accessControlManager.denyIdentityTable((identity, table) -> (identity.getEnabledRoles().contains(role) && "nation".equals(table))); + accessControlManager.rowFilter( + new QualifiedObjectName("blackhole", "default", "nation"), + getSession().getUser(), + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey % 2 = 0")); + + assertAccessAllowed(session, "SELECT nationkey FROM nation"); + } + private static final class DenySetPropertiesSystemAccessControl extends AllowAllSystemAccessControl {