diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index c112292a9cc7..44713ddf4336 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -296,26 +296,23 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap) .withScope(analysis.getAccessControlScope(table), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope + Map assignments = new LinkedHashMap<>(); + for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { + assignments.put(symbol, symbol.toSymbolReference()); + } for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { Field field = plan.getDescriptor().getFieldByIndex(i); - for (Expression mask : columnMasks.getOrDefault(field.getName().orElseThrow(), ImmutableList.of())) { planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); - - Map assignments = new LinkedHashMap<>(); - for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { - assignments.put(symbol, symbol.toSymbolReference()); - } assignments.put(plan.getFieldMappings().get(i), coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask))); - - planBuilder = planBuilder - .withNewRoot(new ProjectNode( - idAllocator.getNextId(), - planBuilder.getRoot(), - Assignments.copyOf(assignments))); } } + planBuilder = planBuilder + .withNewRoot(new ProjectNode( + idAllocator.getNextId(), + planBuilder.getRoot(), + Assignments.copyOf(assignments))); return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext); } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index e42a7a647563..f6e8084a0627 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -180,7 +180,7 @@ public void testSimpleMask() } @Test - public void testMultipleMasksOnSameColumn() + public void testMultipleMasksOnDifferentColumns() { accessControl.reset(); accessControl.columnMask( @@ -191,31 +191,134 @@ public void testMultipleMasksOnSameColumn() accessControl.columnMask( new QualifiedObjectName(CATALOG, "tiny", "orders"), - "custkey", + "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey * 2")); + new ViewExpression(USER, Optional.empty(), Optional.empty(), "'X'")); - assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '-740'"); + assertThat(assertions.query("SELECT custkey, orderstatus FROM orders WHERE orderkey = 1")) + .matches("VALUES (BIGINT '-370', 'X')"); } @Test - public void testMultipleMasksOnDifferentColumns() + public void testMultipleMasksUsingOtherMaskedColumns() { + String query = "SELECT comment, orderstatus, clerk FROM orders WHERE orderkey = 1"; + String expected = "VALUES (CAST('nstructions sleep furiously among ' as varchar(79)), 'O', 'Clerk#000000951')"; + + accessControl.reset(); + assertThat(assertions.query(query)).matches(expected); + + // mask "clerk" and "orderstatus" using "comment" ("comment" appears after both in table definition) + // columns will not be masked since rules are not apply to row values accessControl.reset(); accessControl.columnMask( new QualifiedObjectName(CATALOG, "tiny", "orders"), - "custkey", + "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(comment,'(password: [^ ]+)','password: ****') as varchar(79))")); accessControl.columnMask( new QualifiedObjectName(CATALOG, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "'X'")); + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); - assertThat(assertions.query("SELECT custkey, orderstatus FROM orders WHERE orderkey = 1")) - .matches("VALUES (BIGINT '-370', 'X')"); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '***', clerk)")); + + assertThat(assertions.query(query)).matches(expected); + + // mask "comment" using "clerk" ("clerk" column appears before "comment" in table definition) + // columns will not be masked since rules are not apply to row values + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "comment", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); + + assertThat(assertions.query(query)).matches(expected); + + // now add mask for "orderstatus" using "clerk" + // columns will not be masked since rules are not apply to row values + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderstatus", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "comment", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); + + assertThat(assertions.query(query)).matches(expected); + + query = "SELECT comment, orderstatus, clerk FROM orders WHERE orderkey = 39"; + + accessControl.reset(); + assertThat(assertions.query(query)) + .matches("VALUES (CAST('ole express, ironic requests: ir' as varchar(79)), 'O', 'Clerk#000000659')"); + + // mask "comment" and "orderstatus" using "clerk" ("clerk" appears before "comment" table definition) + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "comment", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('659'), '***', comment)")); + + assertThat(assertions.query(query)) + .matches("VALUES (CAST('***' as varchar(79)), 'O', CAST('***#000000659' as varchar(15)))"); + + assertThat(assertions.query("SELECT comment, orderstatus, clerk, length(clerk) FROM orders WHERE orderkey = 39")) + .matches("VALUES (CAST('***' as varchar(79)), 'O', CAST('***#000000659' as varchar(15)), bigint'13')"); + + // mask "comment" and "orderstatus" using "clerk" ("clerk" appears before "comment" table definition) + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderstatus", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('659'), '*', orderstatus)")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "comment", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('659'), '***', comment)")); + + assertThat(assertions.query(query)) + .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('***#000000659' as varchar(15)))"); } @Test