diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java index 3ab7c8673e4c8..38854590264c0 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java @@ -100,6 +100,13 @@ public void testCreateTableAsSelect() "SELECT 0"); } + @Test + @Override + public void testSubfieldAccessControl() + { + // disabled as accumulo doesn't support complex types + } + @Override public void testDelete() { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index b6ce743c0c789..5a1ad93366ef4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -768,7 +768,7 @@ public JoinUsingAnalysis getJoinUsing(Join node) return joinUsing.get(NodeRef.of(node)); } - public void addTableColumnAndSubfieldReferences(AccessControl accessControl, Identity identity, Multimap tableColumnMap) + public void addTableColumnAndSubfieldReferences(AccessControl accessControl, Identity identity, Multimap tableColumnMap, Multimap tableColumnMapForAccessControl) { AccessControlInfo accessControlInfo = new AccessControlInfo(accessControl, identity); Map> columnReferences = tableColumnReferences.computeIfAbsent(accessControlInfo, k -> new LinkedHashMap<>()); @@ -776,7 +776,7 @@ public void addTableColumnAndSubfieldReferences(AccessControl accessControl, Ide .forEach((key, value) -> columnReferences.computeIfAbsent(key, k -> new HashSet<>()).addAll(value.stream().map(Subfield::getRootName).collect(toImmutableSet()))); Map> columnAndSubfieldReferences = tableColumnAndSubfieldReferences.computeIfAbsent(accessControlInfo, k -> new LinkedHashMap<>()); - tableColumnMap.asMap() + tableColumnMapForAccessControl.asMap() .forEach((key, value) -> columnAndSubfieldReferences.computeIfAbsent(key, k -> new HashSet<>()).addAll(value)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index fde13501f69cf..31d0d4e8b291c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -109,7 +109,9 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; +import com.google.common.collect.Sets; import io.airlift.slice.SliceUtf8; import javax.annotation.Nullable; @@ -124,8 +126,6 @@ import java.util.Set; import java.util.function.Function; -import static com.facebook.presto.common.Subfield.NestedField; -import static com.facebook.presto.common.Subfield.PathElement; import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -153,6 +153,10 @@ import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isNonNullConstant; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteralType; +import static com.facebook.presto.sql.analyzer.FunctionArgumentCheckerForAccessControlUtils.getResolvedLambdaArguments; +import static com.facebook.presto.sql.analyzer.FunctionArgumentCheckerForAccessControlUtils.isTopMostReference; +import static com.facebook.presto.sql.analyzer.FunctionArgumentCheckerForAccessControlUtils.isUnusedArgumentForAccessControl; +import static com.facebook.presto.sql.analyzer.FunctionArgumentCheckerForAccessControlUtils.resolveSubfield; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; @@ -179,7 +183,6 @@ import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Collections.emptyMap; -import static java.util.Collections.reverse; import static java.util.Collections.unmodifiableMap; import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; @@ -207,6 +210,7 @@ public class ExpressionAnalyzer private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>(); private final Set> windowFunctions = new LinkedHashSet<>(); private final Multimap tableColumnAndSubfieldReferences = HashMultimap.create(); + private final Multimap tableColumnAndSubfieldReferencesForAccessControl = HashMultimap.create(); private final Optional transactionId; private final Optional> sessionFunctions; @@ -327,6 +331,11 @@ public Multimap getTableColumnAndSubfieldReferenc return tableColumnAndSubfieldReferences; } + public Multimap getTableColumnAndSubfieldReferencesForAccessControl() + { + return tableColumnAndSubfieldReferencesForAccessControl; + } + private class Visitor extends StackableAstVisitor { @@ -427,13 +436,27 @@ private Type handleResolvedField(Expression node, FieldId fieldId, Field field, if (lambdaArgumentDeclaration != null) { // Lambda argument reference is not a column reference lambdaArgumentReferences.put(NodeRef.of((Identifier) node), lambdaArgumentDeclaration); - return setExpressionType(node, field.getType()); + if (!context.getContext().getResolvedLambdaArguments().containsKey(node)) { + return setExpressionType(node, field.getType()); + } } } // If we found a direct column reference, and we will put it in tableColumnReferencesWithSubFields - if (field.getOriginTable().isPresent() && field.getOriginColumnName().isPresent() && isTopMostReference(node, context)) { - tableColumnAndSubfieldReferences.put(field.getOriginTable().get(), new Subfield(field.getOriginColumnName().get(), ImmutableList.of())); + if (isTopMostReference(node, context)) { + Optional tableName = field.getOriginTable(); + Optional subfield = field.getOriginColumnName().map(column -> new Subfield(column, ImmutableList.of())); + ResolvedSubfield resolvedSubfield = context.getContext().getResolvedLambdaArguments().get(node); + if (resolvedSubfield != null) { + tableName = resolvedSubfield.getResolvedField().getField().getOriginTable(); + subfield = Optional.of(resolvedSubfield.getSubfield()); + } + if (tableName.isPresent() && subfield.isPresent()) { + tableColumnAndSubfieldReferences.put(tableName.get(), subfield.get()); + if (!context.getContext().getUnusedExpressionsForAccessControl().contains(NodeRef.of(node))) { + tableColumnAndSubfieldReferencesForAccessControl.put(tableName.get(), subfield.get()); + } + } } FieldId previous = columnReferences.put(NodeRef.of(node), fieldId); @@ -923,7 +946,8 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext argumentTypesBuilder = ImmutableList.builder(); - for (Expression expression : node.getArguments()) { + for (int index = 0; index < node.getArguments().size(); ++index) { + Expression expression = node.getArguments().get(index); if (expression instanceof LambdaExpression || expression instanceof BindExpression) { argumentTypesBuilder.add(new TypeSignatureProvider( types -> { @@ -942,7 +966,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext(newContext)).getTypeSignature())); } } @@ -968,6 +996,8 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext resolvedLambdaArguments = getResolvedLambdaArguments(node, context, expressionTypes); + for (int i = 0; i < node.getArguments().size(); i++) { Expression expression = node.getArguments().get(i); Type expectedType = functionAndTypeManager.getType(functionMetadata.getArgumentTypes().get(i)); @@ -977,7 +1007,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes()))); + process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes(), resolvedLambdaArguments))); } else { Type actualType = functionAndTypeManager.getType(argumentTypes.get(i).getTypeSignature()); @@ -1283,7 +1313,7 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC fieldToLambdaArgumentDeclaration.put(FieldId.from(resolvedField), lambdaArgument); } - Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.build()))); + Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(context.getContext().inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.build()))); FunctionType functionType = new FunctionType(types, returnType); return setExpressionType(node, functionType); } @@ -1301,7 +1331,7 @@ protected Type visitBindExpression(BindExpression node, StackableAstVisitorConte functionInputTypesBuilder.addAll(context.getContext().getFunctionInputTypes()); List functionInputTypes = functionInputTypesBuilder.build(); - FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes))); + FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes, ImmutableMap.of()))); List argumentTypes = functionType.getArgumentTypes(); int numCapturedValues = node.getValues().size(); @@ -1345,74 +1375,20 @@ public Type visitGroupingOperation(GroupingOperation node, StackableAstVisitorCo } } - private boolean isDereferenceOrSubscript(Expression node) - { - return node instanceof DereferenceExpression || node instanceof SubscriptExpression; - } - - private boolean isTopMostReference(Expression node, StackableAstVisitorContext context) - { - if (!context.getPreviousNode().isPresent()) { - return true; - } - return !isDereferenceOrSubscript((Expression) context.getPreviousNode().get()); - } - private void addColumnSubfieldReferences(Expression node, StackableAstVisitorContext context) { - // If expression is nested with multiple dereferences and subscripts, we only look at the topmost one. - if (!isTopMostReference(node, context)) { + Optional resolvedSubfield = resolveSubfield(node, context, expressionTypes); + if (!resolvedSubfield.isPresent()) { return; } - Scope scope = context.getContext().getScope(); - Expression childNode = node; - List columnDereferences = new ArrayList<>(); - while (true) { - if (childNode instanceof SubscriptExpression) { - SubscriptExpression subscriptExpression = (SubscriptExpression) childNode; - childNode = subscriptExpression.getBase(); - Type baseType = expressionTypes.get(NodeRef.of(childNode)); - if (baseType == null || !(baseType instanceof RowType)) { - continue; - } - int index = toIntExact(((LongLiteral) subscriptExpression.getIndex()).getValue()); - RowType baseRowType = (RowType) baseType; - Optional dereference = baseRowType.getFields().get(index - 1).getName(); - if (!dereference.isPresent()) { - break; - } - columnDereferences.add(new NestedField(dereference.get())); - continue; - } + tableColumnAndSubfieldReferences.put( + resolvedSubfield.get().getResolvedField().getField().getOriginTable().get(), + resolvedSubfield.get().getSubfield()); - QualifiedName childQualifiedName; - if (childNode instanceof DereferenceExpression) { - childQualifiedName = DereferenceExpression.getQualifiedName((DereferenceExpression) childNode); - } - else if (childNode instanceof Identifier) { - childQualifiedName = QualifiedName.of(((Identifier) childNode).getValue()); - } - else { - break; - } - if (childQualifiedName != null) { - Optional resolvedField = scope.tryResolveField(childNode, childQualifiedName); - if (resolvedField.isPresent() && - resolvedField.get().getField().getOriginColumnName().isPresent() && - resolvedField.get().getField().getOriginTable().isPresent()) { - reverse(columnDereferences); - tableColumnAndSubfieldReferences.put( - resolvedField.get().getField().getOriginTable().get(), - new Subfield(resolvedField.get().getField().getOriginColumnName().get(), columnDereferences)); - break; - } - } - if (childNode instanceof DereferenceExpression) { - columnDereferences.add(new NestedField(((DereferenceExpression) childNode).getField().getValue())); - childNode = ((DereferenceExpression) childNode).getBase(); - continue; - } - break; + if (!context.getContext().getUnusedExpressionsForAccessControl().contains(NodeRef.of(node))) { + tableColumnAndSubfieldReferencesForAccessControl.put( + resolvedSubfield.get().getResolvedField().getField().getOriginTable().get(), + resolvedSubfield.get().getSubfield()); } } @@ -1536,7 +1512,7 @@ else if (typeOnlyCoercions.contains(ref)) { } } - private static class Context + public static class Context { private final Scope scope; @@ -1548,35 +1524,61 @@ private static class Context // The mapping from names to corresponding lambda argument declarations when inside a lambda; null otherwise. // Empty map means that the all lambda expressions surrounding the current node has no arguments. private final Map fieldToLambdaArgumentDeclaration; + private final Map resolvedLambdaArguments; + private final Set> unusedExpressionsForAccessControl; private Context( Scope scope, List functionInputTypes, - Map fieldToLambdaArgumentDeclaration) + Map fieldToLambdaArgumentDeclaration, + Map resolvedLambdaArguments, + Set> unusedExpressionsForAccessControl) { this.scope = requireNonNull(scope, "scope is null"); this.functionInputTypes = functionInputTypes; this.fieldToLambdaArgumentDeclaration = fieldToLambdaArgumentDeclaration; + this.resolvedLambdaArguments = requireNonNull(resolvedLambdaArguments, "resolvedLambdaArguments is null"); + this.unusedExpressionsForAccessControl = requireNonNull(unusedExpressionsForAccessControl, "unusedExpressionsForAccessControl is null"); } public static Context notInLambda(Scope scope) { - return new Context(scope, null, null); + return new Context(scope, null, null, ImmutableMap.of(), ImmutableSet.of()); } - public static Context inLambda(Scope scope, Map fieldToLambdaArgumentDeclaration) + public Context inLambda(Scope scope, Map fieldToLambdaArgumentDeclaration) { - return new Context(scope, null, requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null")); + return new Context( + scope, + null, + requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null"), + resolvedLambdaArguments, + unusedExpressionsForAccessControl); } - public Context expectingLambda(List functionInputTypes) + public Context expectingLambda(List functionInputTypes, Map resolvedLambdaArguments) { - return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.fieldToLambdaArgumentDeclaration); + return new Context( + scope, + requireNonNull(functionInputTypes, "functionInputTypes is null"), + this.fieldToLambdaArgumentDeclaration, + resolvedLambdaArguments, + unusedExpressionsForAccessControl); } public Context notExpectingLambda() { - return new Context(scope, null, this.fieldToLambdaArgumentDeclaration); + return new Context(scope, null, this.fieldToLambdaArgumentDeclaration, ImmutableMap.of(), unusedExpressionsForAccessControl); + } + + public Context withUnusedExpressionsForAccessControl(Set> unusedExpressionsForAccessControl) + { + return new Context( + scope, + functionInputTypes, + fieldToLambdaArgumentDeclaration, + resolvedLambdaArguments, + Sets.union(unusedExpressionsForAccessControl, this.unusedExpressionsForAccessControl)); } Scope getScope() @@ -1605,6 +1607,16 @@ public List getFunctionInputTypes() checkState(isExpectingLambda()); return functionInputTypes; } + + public Map getResolvedLambdaArguments() + { + return resolvedLambdaArguments; + } + + public Set> getUnusedExpressionsForAccessControl() + { + return unusedExpressionsForAccessControl; + } } public static FunctionHandle resolveFunction( @@ -1720,7 +1732,7 @@ public static ExpressionAnalysis analyzeExpression( analysis.addFunctionHandles(resolvedFunctions); analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); - analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), analyzer.getTableColumnAndSubfieldReferences()); + analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), analyzer.getTableColumnAndSubfieldReferences(), analyzer.getTableColumnAndSubfieldReferencesForAccessControl()); return new ExpressionAnalysis( expressionTypes, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FunctionArgumentCheckerForAccessControlUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FunctionArgumentCheckerForAccessControlUtils.java new file mode 100644 index 0000000000000..464e178d2098b --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FunctionArgumentCheckerForAccessControlUtils.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.common.Subfield; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.sql.tree.DereferenceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SubscriptExpression; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.tree.StackableAstVisitor.StackableAstVisitorContext; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.Math.toIntExact; +import static java.util.Collections.reverse; + +public class FunctionArgumentCheckerForAccessControlUtils +{ + private static final QualifiedName TRANSFORM = QualifiedName.of("transform"); + private static final QualifiedName CARDINALITY = QualifiedName.of("cardinality"); + + private FunctionArgumentCheckerForAccessControlUtils() {} + + // Returns whether function argument at `argumentIndex` for function `node` needs to be checked + // for column level access control. + // For e.g., consider SQL `transform(arr, col -> col.x)` + // Here, we only need to check for access of subfield `x` in column `arr` which is of type `Array`. + // So we can just parse lambda and ignore the first argument for access checks. + public static boolean isUnusedArgumentForAccessControl(FunctionCall node, int argumentIndex, ExpressionAnalyzer.Context context) + { + if (node.getName().equals(TRANSFORM)) { + checkState(node.getArguments().size() == 2); + return argumentIndex == 0; + } + if (node.getName().equals(CARDINALITY)) { + checkState(node.getArguments().size() == 1); + return argumentIndex == 0; + } + return false; + } + + // Parses arguments of function `node` which are a lambda expression, and returns a map + // of their lambda arguments to resolved subfield. + // For e.g., consider SQL `SELECT transform(arr, col -> col.x) FROM table` + // Return value = Map('col' -> ResolvedSubfield(table.arr)) + public static Map getResolvedLambdaArguments( + FunctionCall node, + StackableAstVisitorContext context, + Map, Type> expressionTypes) + { + ImmutableMap.Builder resolvedLambdaArguments = ImmutableMap.builder(); + if (node.getName().equals(TRANSFORM)) { + checkState(node.getArguments().size() == 2); + if (!(node.getArguments().get(1) instanceof LambdaExpression)) { + return ImmutableMap.of(); + } + Expression arrayExpression = node.getArguments().get(0); + LambdaExpression lambdaExpression = ((LambdaExpression) node.getArguments().get(1)); + Optional resolvedSubfield = resolveSubfield(arrayExpression, context, expressionTypes); + if (resolvedSubfield.isPresent()) { + resolvedLambdaArguments.put( + lambdaExpression.getArguments().get(0).getName(), + resolvedSubfield.get()); + } + } + return resolvedLambdaArguments.build(); + } + + public static Optional resolveSubfield( + Expression node, + StackableAstVisitorContext context, + Map, Type> expressionTypes) + { + // If expression is nested with multiple dereferences and subscripts, we only look at the topmost one. + if (!isTopMostReference(node, context)) { + return Optional.empty(); + } + + Scope scope = context.getContext().getScope(); + Expression childNode = node; + List columnDereferences = new ArrayList<>(); + while (true) { + // Dereference row/array/map expressions + if (childNode instanceof SubscriptExpression) { + SubscriptExpression subscriptExpression = (SubscriptExpression) childNode; + childNode = subscriptExpression.getBase(); + Type baseType = expressionTypes.get(NodeRef.of(childNode)); + if (baseType == null || !(baseType instanceof RowType)) { + continue; + } + int index = toIntExact(((LongLiteral) subscriptExpression.getIndex()).getValue()); + RowType baseRowType = (RowType) baseType; + Optional dereference = baseRowType.getFields().get(index - 1).getName(); + if (!dereference.isPresent()) { + break; + } + columnDereferences.add(new Subfield.NestedField(dereference.get())); + continue; + } + + QualifiedName childQualifiedName; + // Dereference subfield expressions + if (childNode instanceof DereferenceExpression) { + childQualifiedName = DereferenceExpression.getQualifiedName((DereferenceExpression) childNode); + } + // Base case + else if (childNode instanceof Identifier) { + childQualifiedName = QualifiedName.of(((Identifier) childNode).getValue()); + } + else { + break; + } + // If we found the full de-referenced expression, return it as a ResolvedSubfield + if (childQualifiedName != null) { + Optional resolvedField = scope.tryResolveField(childNode, childQualifiedName); + if (resolvedField.isPresent() && !resolvedField.get().getField().getOriginTable().isPresent()) { + // Try to resolve using lambda expressions + Optional resolvedSubField = Optional.ofNullable(context.getContext().getResolvedLambdaArguments().get(childNode)); + if (resolvedSubField.isPresent()) { + resolvedField = Optional.of(resolvedSubField.get().getResolvedField()); + columnDereferences.addAll(Lists.reverse(resolvedSubField.get().getSubfield().getPath())); + } + } + if (resolvedField.isPresent() && + resolvedField.get().getField().getOriginColumnName().isPresent() && + resolvedField.get().getField().getOriginTable().isPresent()) { + reverse(columnDereferences); + return Optional.of(new ResolvedSubfield( + resolvedField.get(), + new Subfield(resolvedField.get().getField().getOriginColumnName().get(), columnDereferences))); + } + } + // If we cannot resolve full de-referenced name, that means that there are + // more dereferences to be resolved, so we continue the while loop with new childNode. + if (childNode instanceof DereferenceExpression) { + columnDereferences.add(new Subfield.NestedField(((DereferenceExpression) childNode).getField().getValue())); + childNode = ((DereferenceExpression) childNode).getBase(); + continue; + } + break; + } + return Optional.empty(); + } + + public static boolean isDereferenceOrSubscript(Expression node) + { + return node instanceof DereferenceExpression || node instanceof SubscriptExpression; + } + + public static boolean isTopMostReference(Expression node, StackableAstVisitorContext context) + { + if (!context.getPreviousNode().isPresent()) { + return true; + } + return !isDereferenceOrSubscript((Expression) context.getPreviousNode().get()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ResolvedSubfield.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ResolvedSubfield.java new file mode 100644 index 0000000000000..57d1995543ac7 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ResolvedSubfield.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.common.Subfield; + +import static java.util.Objects.requireNonNull; + +public class ResolvedSubfield +{ + private final ResolvedField resolvedField; + private final Subfield subfield; + + public ResolvedSubfield(ResolvedField resolvedField, Subfield subfield) + { + this.resolvedField = requireNonNull(resolvedField, "resolvedField is null"); + this.subfield = requireNonNull(subfield, "subfield is null"); + } + + public ResolvedField getResolvedField() + { + return resolvedField; + } + + public Subfield getSubfield() + { + return subfield; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 773c080fc3912..2fdcb79c8ffa7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -595,12 +595,10 @@ protected Scope visitAnalyze(Analyze node, Optional scope) .orElseThrow(() -> (new SemanticException(MISSING_TABLE, node, "Table '%s' does not exist", tableName))); // user must have read and insert permission in order to analyze stats of a table - analysis.addTableColumnAndSubfieldReferences( - accessControl, - session.getIdentity(), - ImmutableMultimap.builder() - .putAll(tableName, metadata.getColumnHandles(session, tableHandle).keySet().stream().map(column -> new Subfield(column, ImmutableList.of())).collect(toImmutableSet())) - .build()); + Multimap tableColumnMap = ImmutableMultimap.builder() + .putAll(tableName, metadata.getColumnHandles(session, tableHandle).keySet().stream().map(column -> new Subfield(column, ImmutableList.of())).collect(toImmutableSet())) + .build(); + analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), tableColumnMap, tableColumnMap); try { accessControl.checkCanInsertIntoTable(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName); } @@ -1955,16 +1953,12 @@ private Scope analyzeJoinUsing(Join node, List columns, Optional tableColumnMap = ImmutableMultimap.of(leftField.get().getField().getOriginTable().get(), new Subfield(leftField.get().getField().getOriginColumnName().get(), ImmutableList.of())); + analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), tableColumnMap, tableColumnMap); } if (rightField.get().getField().getOriginTable().isPresent() && rightField.get().getField().getOriginColumnName().isPresent()) { - analysis.addTableColumnAndSubfieldReferences( - accessControl, - session.getIdentity(), - ImmutableMultimap.of(rightField.get().getField().getOriginTable().get(), new Subfield(rightField.get().getField().getOriginColumnName().get(), ImmutableList.of()))); + Multimap tableColumnMap = ImmutableMultimap.of(rightField.get().getField().getOriginTable().get(), new Subfield(rightField.get().getField().getOriginColumnName().get(), ImmutableList.of())); + analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), tableColumnMap, tableColumnMap); } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java index 332bc32bcced1..24828eae3c9ce 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java @@ -301,7 +301,7 @@ public void checkCanSetCatalogSessionProperty(TransactionId transactionId, Ident @Override public void checkCanSelectFromColumns(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, Set columnOrSubfieldNames) { - Set columns = columnOrSubfieldNames.stream().map(subfield -> subfield.getRootName()).collect(toImmutableSet()); + Set columns = columnOrSubfieldNames.stream().map(subfield -> subfield.toString()).collect(toImmutableSet()); if (shouldDenyPrivilege(identity.getUser(), tableName.getObjectName(), SELECT_COLUMN)) { denySelectColumns(tableName.toString(), columns); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java index c8dfbde771846..9282abd2af9d2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java @@ -256,7 +256,13 @@ public void setup() new ColumnMetadata("b", RowType.from(ImmutableList.of( new RowType.Field(Optional.of("w"), BIGINT), new RowType.Field(Optional.of("x"), - new ArrayType(new ArrayType(RowType.from(ImmutableList.of(new RowType.Field(Optional.of("y"), BIGINT))))))))))), + new ArrayType(new ArrayType(RowType.from(ImmutableList.of(new RowType.Field(Optional.of("y"), BIGINT))))))))), + new ColumnMetadata("c", RowType.from(ImmutableList.of( + new RowType.Field( + Optional.of("x"), + new ArrayType(RowType.from(ImmutableList.of( + new RowType.Field(Optional.of("x"), BIGINT), + new RowType.Field(Optional.of("y"), BIGINT)))))))))), false)); // table with columns containing special characters diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java index 0436f08d4c810..08c14a8408d67 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java @@ -38,6 +38,46 @@ public class TestColumnAndSubfieldAnalyzer extends AbstractAnalyzerTest { + @Test + public void testCardinality() + { + assertTableColumns( + "SELECT cardinality(a) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of())); + + assertTableColumns( + "SELECT transform(b.x, yo -> cardinality(yo)) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of())); + } + + @Test + public void testTransform() + { + assertTableColumns( + "SELECT transform(a, yo -> yo.x + yo.y) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("a.x", "a.y"))); + assertTableColumns( + "SELECT transform(a, yo -> yo) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("a"))); + assertTableColumns( + "SELECT transform(c.x, yo -> yo.x) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("c.x.x"))); + assertTableColumns( + "SELECT transform(c.x, yo -> yo[1]) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("c.x.x"))); + assertTableColumns( + "SELECT transform(b.x, yo -> transform(yo, yoo -> yoo.y)) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("b.x.y"))); + assertTableColumns( + "SELECT transform(tbl.b.x, yo -> transform(yo, yoo -> yoo.y)) FROM tpch.s1.t11 tbl", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("b.x.y"))); + + // We only parse lambda in transform, when first expression is simple + assertTableColumns( + "SELECT transform(reverse(a), yo -> yo.x + yo.y) FROM tpch.s1.t11", + ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("a"))); + } + @Test public void testSelect() { diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueries.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueries.java index 5b0b14cb33fc9..5bf57d1259e40 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueries.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueries.java @@ -42,4 +42,11 @@ public void testLargeQuerySuccess() { // TODO: disabled until we fix stackoverflow error in ExpressionTreeRewriter } + + @Test + @Override + public void testSubfieldAccessControl() + { + // disabled as raptor doesn't support complex types + } } diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueriesBucketed.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueriesBucketed.java index 6d07d037d5d41..e7c6444b2b2fa 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueriesBucketed.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/integration/TestRaptorDistributedQueriesBucketed.java @@ -15,6 +15,7 @@ import com.facebook.presto.testing.QueryRunner; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; import static com.facebook.presto.raptor.RaptorQueryRunner.createRaptorQueryRunner; @@ -28,6 +29,13 @@ protected QueryRunner createQueryRunner() return createRaptorQueryRunner(ImmutableMap.of(), true, true, false, ImmutableMap.of("storage.orc.optimized-writer-stage", "ENABLED_AND_VALIDATED")); } + @Test + @Override + public void testSubfieldAccessControl() + { + // disabled as raptor doesn't support complex types + } + @Override protected boolean supportsNotNullColumns() { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java index 9a3ed92046a9d..aa119a9975e49 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java @@ -130,6 +130,38 @@ public void testResetSession() assertEquals(result.getResetSessionProperties(), ImmutableSet.of(TESTING_CATALOG + ".connector_string")); } + @Test + public void testSubfieldAccessControl() + { + Session session = Session.builder(getSession()) + .setSystemProperty("check_access_control_with_subfields", "true") + .build(); + assertUpdate( + session, + "CREATE TABLE test_subfield AS SELECT CAST(ROW(1, 2, ARRAY[ROW(1, 2)]) AS ROW(f1 int, f2 int, f3 ARRAY)) x", + 1); + assertAccessAllowed(session, "SELECT x.f1 from test_subfield"); + assertAccessAllowed(session, "SELECT x.f1 from test_subfield", privilege("x.f2", SELECT_COLUMN)); + assertAccessDenied( + session, + "SELECT x.f1 from test_subfield", + ".*Cannot select from columns \\[x.f1\\].*", + privilege("x.f1", SELECT_COLUMN)); + + assertAccessDenied(session, + "SELECT transform(x.f3, col -> col.ff1) from test_subfield", + ".*Cannot select from columns \\[x.f3.ff1\\].*", privilege("x.f3.ff1", SELECT_COLUMN)); + assertAccessAllowed(session, + "SELECT transform(x.f3, col -> col.ff1) from test_subfield", + privilege("x.f3.ff2", SELECT_COLUMN)); + + assertAccessAllowed(session, + "SELECT cardinality(x.f3) from test_subfield", + privilege("x.f3", SELECT_COLUMN)); + + assertUpdate("DROP TABLE test_subfield"); + } + @Test public void testCreateTable() {