Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ public void testCreateTableAsSelect()
"SELECT 0");
}

@Test
@Override
public void testSubfieldAccessControl()
{
    // Intentionally empty override: the inherited test exercises subfield-level
    // access control on complex types (ROW/ARRAY/MAP), which the Accumulo
    // connector does not support, so the scenario cannot be run here.
}

@Override
public void testDelete()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -768,15 +768,15 @@ public JoinUsingAnalysis getJoinUsing(Join node)
return joinUsing.get(NodeRef.of(node));
}

public void addTableColumnAndSubfieldReferences(AccessControl accessControl, Identity identity, Multimap<QualifiedObjectName, Subfield> tableColumnMap)
public void addTableColumnAndSubfieldReferences(AccessControl accessControl, Identity identity, Multimap<QualifiedObjectName, Subfield> tableColumnMap, Multimap<QualifiedObjectName, Subfield> tableColumnMapForAccessControl)
{
AccessControlInfo accessControlInfo = new AccessControlInfo(accessControl, identity);
Map<QualifiedObjectName, Set<String>> columnReferences = tableColumnReferences.computeIfAbsent(accessControlInfo, k -> new LinkedHashMap<>());
tableColumnMap.asMap()
.forEach((key, value) -> columnReferences.computeIfAbsent(key, k -> new HashSet<>()).addAll(value.stream().map(Subfield::getRootName).collect(toImmutableSet())));

Map<QualifiedObjectName, Set<Subfield>> columnAndSubfieldReferences = tableColumnAndSubfieldReferences.computeIfAbsent(accessControlInfo, k -> new LinkedHashMap<>());
tableColumnMap.asMap()
tableColumnMapForAccessControl.asMap()
.forEach((key, value) -> columnAndSubfieldReferences.computeIfAbsent(key, k -> new HashSet<>()).addAll(value));
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.analyzer;

import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.tree.StackableAstVisitor.StackableAstVisitorContext;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.Math.toIntExact;
import static java.util.Collections.reverse;

public class FunctionArgumentCheckerForAccessControlUtils
{
private static final QualifiedName TRANSFORM = QualifiedName.of("transform");
private static final QualifiedName CARDINALITY = QualifiedName.of("cardinality");

private FunctionArgumentCheckerForAccessControlUtils() {}

// Returns whether function argument at `argumentIndex` for function `node` needs to be checked
// for column level access control.
// For e.g., consider SQL `transform(arr, col -> col.x)`
// Here, we only need to check for access of subfield `x` in column `arr` which is of type `Array<struct>`.
// So we can just parse lambda and ignore the first argument for access checks.
public static boolean isUnusedArgumentForAccessControl(FunctionCall node, int argumentIndex, ExpressionAnalyzer.Context context)
{
if (node.getName().equals(TRANSFORM)) {
checkState(node.getArguments().size() == 2);
return argumentIndex == 0;
}
if (node.getName().equals(CARDINALITY)) {
checkState(node.getArguments().size() == 1);
return argumentIndex == 0;
}
return false;
}

// Parses arguments of function `node` which are a lambda expression, and returns a map
// of their lambda arguments to resolved subfield.
// For e.g., consider SQL `SELECT transform(arr, col -> col.x) FROM table`
// Return value = Map('col' -> ResolvedSubfield(table.arr))
public static Map<Identifier, ResolvedSubfield> getResolvedLambdaArguments(
FunctionCall node,
StackableAstVisitorContext<ExpressionAnalyzer.Context> context,
Map<NodeRef<Expression>, Type> expressionTypes)
{
ImmutableMap.Builder<Identifier, ResolvedSubfield> resolvedLambdaArguments = ImmutableMap.builder();
if (node.getName().equals(TRANSFORM)) {
checkState(node.getArguments().size() == 2);
if (!(node.getArguments().get(1) instanceof LambdaExpression)) {
return ImmutableMap.of();
}
Expression arrayExpression = node.getArguments().get(0);
LambdaExpression lambdaExpression = ((LambdaExpression) node.getArguments().get(1));
Optional<ResolvedSubfield> resolvedSubfield = resolveSubfield(arrayExpression, context, expressionTypes);
if (resolvedSubfield.isPresent()) {
resolvedLambdaArguments.put(
lambdaExpression.getArguments().get(0).getName(),
resolvedSubfield.get());
}
}
return resolvedLambdaArguments.build();
}

public static Optional<ResolvedSubfield> resolveSubfield(
Expression node,
StackableAstVisitorContext<ExpressionAnalyzer.Context> context,
Map<NodeRef<Expression>, Type> expressionTypes)
{
// If expression is nested with multiple dereferences and subscripts, we only look at the topmost one.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? why not use recursion to go into all of them?

Copy link
Copy Markdown
Contributor Author

@pranjalssh pranjalssh Jan 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is moved from ExpressionAnalyzer as well. This function is called once for each dereference, when multiple dereferences are nested. So, we have this check that only looks at the dereference chain once.

if (!isTopMostReference(node, context)) {
return Optional.empty();
}

Scope scope = context.getContext().getScope();
Expression childNode = node;
List<Subfield.PathElement> columnDereferences = new ArrayList<>();
while (true) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think all the different cases here are a bit confusing. Needs some comments, but I think it would be clearer if written as a visitor instead of a while loop with each case determining whether to continue the loop. It would also be easy to introduce an infinite loop in this logic(continue gets called without the child nod getting updated to something that would break the loop eventually)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to modify this logic in this PR. This is an existing logic moved from ExpressionAnalyzer to here, which is already tested. I can put up a todo to refactor it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comments

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it. I guess it's okay since the code already existed. would be good to refactor though as a follow up.

// Dereference row/array/map expressions
if (childNode instanceof SubscriptExpression) {
SubscriptExpression subscriptExpression = (SubscriptExpression) childNode;
childNode = subscriptExpression.getBase();
Type baseType = expressionTypes.get(NodeRef.of(childNode));
if (baseType == null || !(baseType instanceof RowType)) {
continue;
}
int index = toIntExact(((LongLiteral) subscriptExpression.getIndex()).getValue());
RowType baseRowType = (RowType) baseType;
Optional<String> dereference = baseRowType.getFields().get(index - 1).getName();
if (!dereference.isPresent()) {
break;
}
columnDereferences.add(new Subfield.NestedField(dereference.get()));
continue;
}

QualifiedName childQualifiedName;
// Dereference subfield expressions
if (childNode instanceof DereferenceExpression) {
childQualifiedName = DereferenceExpression.getQualifiedName((DereferenceExpression) childNode);
}
// Base case
else if (childNode instanceof Identifier) {
childQualifiedName = QualifiedName.of(((Identifier) childNode).getValue());
}
else {
break;
}
// If we found the full de-referenced expression, return it as a ResolvedSubfield
if (childQualifiedName != null) {
Optional<ResolvedField> resolvedField = scope.tryResolveField(childNode, childQualifiedName);
if (resolvedField.isPresent() && !resolvedField.get().getField().getOriginTable().isPresent()) {
// Try to resolve using lambda expressions
Optional<ResolvedSubfield> resolvedSubField = Optional.ofNullable(context.getContext().getResolvedLambdaArguments().get(childNode));
if (resolvedSubField.isPresent()) {
resolvedField = Optional.of(resolvedSubField.get().getResolvedField());
columnDereferences.addAll(Lists.reverse(resolvedSubField.get().getSubfield().getPath()));
}
}
if (resolvedField.isPresent() &&
resolvedField.get().getField().getOriginColumnName().isPresent() &&
resolvedField.get().getField().getOriginTable().isPresent()) {
reverse(columnDereferences);
return Optional.of(new ResolvedSubfield(
resolvedField.get(),
new Subfield(resolvedField.get().getField().getOriginColumnName().get(), columnDereferences)));
}
}
// If we cannot resolve full de-referenced name, that means that there are
// more dereferences to be resolved, so we continue the while loop with new childNode.
if (childNode instanceof DereferenceExpression) {
columnDereferences.add(new Subfield.NestedField(((DereferenceExpression) childNode).getField().getValue()));
childNode = ((DereferenceExpression) childNode).getBase();
continue;
}
break;
}
return Optional.empty();
}

public static boolean isDereferenceOrSubscript(Expression node)
{
return node instanceof DereferenceExpression || node instanceof SubscriptExpression;
}

public static boolean isTopMostReference(Expression node, StackableAstVisitorContext<ExpressionAnalyzer.Context> context)
{
if (!context.getPreviousNode().isPresent()) {
return true;
}
return !isDereferenceOrSubscript((Expression) context.getPreviousNode().get());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.analyzer;

import com.facebook.presto.common.Subfield;

import static java.util.Objects.requireNonNull;

/**
 * Pairs a {@link ResolvedField} with the concrete {@link Subfield} (origin column
 * root plus nested path) that an expression accesses; used to carry subfield
 * information through analysis for column/subfield-level access control checks.
 * Immutable and null-hostile.
 */
public class ResolvedSubfield
{
    // Field as resolved against the analysis scope (carries origin table/column).
    private final ResolvedField resolvedField;
    // Full subfield path rooted at the origin column name.
    private final Subfield subfield;

    public ResolvedSubfield(ResolvedField resolvedField, Subfield subfield)
    {
        this.resolvedField = requireNonNull(resolvedField, "resolvedField is null");
        this.subfield = requireNonNull(subfield, "subfield is null");
    }

    /** Returns the resolved field this subfield was derived from. */
    public ResolvedField getResolvedField()
    {
        return resolvedField;
    }

    /** Returns the subfield path (root column name plus nested path elements). */
    public Subfield getSubfield()
    {
        return subfield;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -595,12 +595,10 @@ protected Scope visitAnalyze(Analyze node, Optional<Scope> scope)
.orElseThrow(() -> (new SemanticException(MISSING_TABLE, node, "Table '%s' does not exist", tableName)));

// user must have read and insert permission in order to analyze stats of a table
analysis.addTableColumnAndSubfieldReferences(
accessControl,
session.getIdentity(),
ImmutableMultimap.<QualifiedObjectName, Subfield>builder()
.putAll(tableName, metadata.getColumnHandles(session, tableHandle).keySet().stream().map(column -> new Subfield(column, ImmutableList.of())).collect(toImmutableSet()))
.build());
Multimap<QualifiedObjectName, Subfield> tableColumnMap = ImmutableMultimap.<QualifiedObjectName, Subfield>builder()
.putAll(tableName, metadata.getColumnHandles(session, tableHandle).keySet().stream().map(column -> new Subfield(column, ImmutableList.of())).collect(toImmutableSet()))
.build();
analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), tableColumnMap, tableColumnMap);
try {
accessControl.checkCanInsertIntoTable(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName);
}
Expand Down Expand Up @@ -1955,16 +1953,12 @@ private Scope analyzeJoinUsing(Join node, List<Identifier> columns, Optional<Sco
analysis.addColumnReference(NodeRef.of(column), FieldId.from(leftField.get()));
analysis.addColumnReference(NodeRef.of(column), FieldId.from(rightField.get()));
if (leftField.get().getField().getOriginTable().isPresent() && leftField.get().getField().getOriginColumnName().isPresent()) {
analysis.addTableColumnAndSubfieldReferences(
accessControl,
session.getIdentity(),
ImmutableMultimap.of(leftField.get().getField().getOriginTable().get(), new Subfield(leftField.get().getField().getOriginColumnName().get(), ImmutableList.of())));
Multimap<QualifiedObjectName, Subfield> tableColumnMap = ImmutableMultimap.of(leftField.get().getField().getOriginTable().get(), new Subfield(leftField.get().getField().getOriginColumnName().get(), ImmutableList.of()));
analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), tableColumnMap, tableColumnMap);
}
if (rightField.get().getField().getOriginTable().isPresent() && rightField.get().getField().getOriginColumnName().isPresent()) {
analysis.addTableColumnAndSubfieldReferences(
accessControl,
session.getIdentity(),
ImmutableMultimap.of(rightField.get().getField().getOriginTable().get(), new Subfield(rightField.get().getField().getOriginColumnName().get(), ImmutableList.of())));
Multimap<QualifiedObjectName, Subfield> tableColumnMap = ImmutableMultimap.of(rightField.get().getField().getOriginTable().get(), new Subfield(rightField.get().getField().getOriginColumnName().get(), ImmutableList.of()));
analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), tableColumnMap, tableColumnMap);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public void checkCanSetCatalogSessionProperty(TransactionId transactionId, Ident
@Override
public void checkCanSelectFromColumns(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, Set<Subfield> columnOrSubfieldNames)
{
Set<String> columns = columnOrSubfieldNames.stream().map(subfield -> subfield.getRootName()).collect(toImmutableSet());
Set<String> columns = columnOrSubfieldNames.stream().map(subfield -> subfield.toString()).collect(toImmutableSet());
if (shouldDenyPrivilege(identity.getUser(), tableName.getObjectName(), SELECT_COLUMN)) {
denySelectColumns(tableName.toString(), columns);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,13 @@ public void setup()
new ColumnMetadata("b", RowType.from(ImmutableList.of(
new RowType.Field(Optional.of("w"), BIGINT),
new RowType.Field(Optional.of("x"),
new ArrayType(new ArrayType(RowType.from(ImmutableList.of(new RowType.Field(Optional.of("y"), BIGINT))))))))))),
new ArrayType(new ArrayType(RowType.from(ImmutableList.of(new RowType.Field(Optional.of("y"), BIGINT))))))))),
new ColumnMetadata("c", RowType.from(ImmutableList.of(
new RowType.Field(
Optional.of("x"),
new ArrayType(RowType.from(ImmutableList.of(
new RowType.Field(Optional.of("x"), BIGINT),
new RowType.Field(Optional.of("y"), BIGINT)))))))))),
false));

// table with columns containing special characters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,46 @@
public class TestColumnAndSubfieldAnalyzer
extends AbstractAnalyzerTest
{
@Test
public void testCardinality()
{
    // cardinality() reads only the size of the collection, never its elements,
    // so no column/subfield references should be recorded for access control.
    assertTableColumns(
            "SELECT cardinality(a) FROM tpch.s1.t11",
            ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of()));

    // The same holds when cardinality() is applied inside a transform() lambda.
    assertTableColumns(
            "SELECT transform(b.x, yo -> cardinality(yo)) FROM tpch.s1.t11",
            ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of()));
}

@Test
public void testTransform()
{
assertTableColumns(
Comment thread
pranjalssh marked this conversation as resolved.
Outdated
"SELECT transform(a, yo -> yo.x + yo.y) FROM tpch.s1.t11",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("a.x", "a.y")));
assertTableColumns(
"SELECT transform(a, yo -> yo) FROM tpch.s1.t11",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("a")));
assertTableColumns(
"SELECT transform(c.x, yo -> yo.x) FROM tpch.s1.t11",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("c.x.x")));
assertTableColumns(
"SELECT transform(c.x, yo -> yo[1]) FROM tpch.s1.t11",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("c.x.x")));
assertTableColumns(
"SELECT transform(b.x, yo -> transform(yo, yoo -> yoo.y)) FROM tpch.s1.t11",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("b.x.y")));
assertTableColumns(
"SELECT transform(tbl.b.x, yo -> transform(yo, yoo -> yoo.y)) FROM tpch.s1.t11 tbl",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("b.x.y")));

// We only parse lambda in transform, when first expression is simple
assertTableColumns(
"SELECT transform(reverse(a), yo -> yo.x + yo.y) FROM tpch.s1.t11",
ImmutableMap.of(QualifiedObjectName.valueOf("tpch.s1.t11"), ImmutableSet.of("a")));
}

@Test
public void testSelect()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,11 @@ public void testLargeQuerySuccess()
{
// TODO: disabled until we fix stackoverflow error in ExpressionTreeRewriter
}

@Test
@Override
public void testSubfieldAccessControl()
{
    // Intentionally empty override: the inherited test exercises subfield-level
    // access control on complex types (ROW/ARRAY/MAP), which the Raptor
    // connector does not support, so the scenario cannot be run here.
}
}
Loading