Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1132,24 +1132,15 @@ else if (frame.getType() == GROUPS) {
List<TypeSignature> arguments = functionMetadata.getArgumentTypes();
String functionName = functionMetadata.getName().toString();

if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase())) {
if (arguments.size() > 1) {
arguments.stream()
.skip(1)
.filter(arg -> {
String base = arg.getBase();
return "function".equals(base) || "lambda".equals(base);
})
.findFirst()
.ifPresent(arg -> {
String warningMessage = createWarningMessage(node,
String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName));
warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage));
});
}
else if (arguments.size() == 1) {
String base = arguments.get(0).getBase();
if ("function".equals(base) || "lambda".equals(base)) {
if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase()) &&
"map_filter".equalsIgnoreCase(functionMetadata.getName().getObjectName()) &&
arguments.size() > 1 && node.getArguments().size() >= 2) {
Expression mapArg = node.getArguments().get(0);
Expression lambdaArg = node.getArguments().get(1);

if (containsFeatures(mapArg) && lambdaArg instanceof LambdaExpression) {
LambdaExpression lambda = (LambdaExpression) lambdaArg;
if (lambda.getArguments().size() == 2 && isKeyOnlyMembershipFilter(lambda)) {
String warningMessage = createWarningMessage(node,
String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName));
warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage));
Expand Down Expand Up @@ -1216,6 +1207,100 @@ private String createWarningMessage(Node node, String message)
}
}

private boolean isKeyOnlyMembershipFilter(LambdaExpression lambda)
{
String valueArgName = lambda.getArguments().get(1).getName().getValue();
Expression body = lambda.getBody();

if (expressionReferencesName(body, valueArgName)) {
return false;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: The 'isSimpleKeyEquality' method may not handle all logical cases, especially with AND operators.

Consider updating the logic to also handle AND operators if key-only membership applies to them.

}

return isSimpleKeyEquality(body);
}

private boolean expressionReferencesName(Expression expression, String name)
{
if (expression == null) {
return false;
}
if (expression instanceof Identifier) {
return ((Identifier) expression).getValue().equalsIgnoreCase(name);
}
if (expression instanceof ComparisonExpression) {
ComparisonExpression comp = (ComparisonExpression) expression;
return expressionReferencesName(comp.getLeft(), name) || expressionReferencesName(comp.getRight(), name);
}
if (expression instanceof LogicalBinaryExpression) {
LogicalBinaryExpression logical = (LogicalBinaryExpression) expression;
return expressionReferencesName(logical.getLeft(), name) || expressionReferencesName(logical.getRight(), name);
}
if (expression instanceof InPredicate) {
InPredicate inPred = (InPredicate) expression;
return expressionReferencesName(inPred.getValue(), name) || expressionReferencesName(inPred.getValueList(), name);
}
if (expression instanceof InListExpression) {
InListExpression inList = (InListExpression) expression;
for (Expression value : inList.getValues()) {
if (expressionReferencesName(value, name)) {
return true;
}
}
}
if (expression instanceof ArithmeticBinaryExpression) {
ArithmeticBinaryExpression arith = (ArithmeticBinaryExpression) expression;
return expressionReferencesName(arith.getLeft(), name) || expressionReferencesName(arith.getRight(), name);
}
if (expression instanceof FunctionCall) {
FunctionCall func = (FunctionCall) expression;
for (Expression arg : func.getArguments()) {
if (expressionReferencesName(arg, name)) {
return true;
}
}
}
// Literals don't reference any names
return false;
}

private boolean containsFeatures(Expression expression)
{
if (expression instanceof Identifier) {
return ((Identifier) expression).getValue().toLowerCase().contains("features");
}
if (expression instanceof SymbolReference) {
return ((SymbolReference) expression).getName().toLowerCase().contains("features");
}
if (expression instanceof DereferenceExpression) {
DereferenceExpression deref = (DereferenceExpression) expression;
return containsFeatures(deref.getBase()) || deref.getField().getValue().toLowerCase().contains("features");
}
return false;
}

private boolean isSimpleKeyEquality(Expression expression)
{
if (expression instanceof ComparisonExpression) {
ComparisonExpression comparison = (ComparisonExpression) expression;
return comparison.getOperator() == ComparisonExpression.Operator.EQUAL;
}
if (expression instanceof InPredicate) {
return true;
}
if (expression instanceof LogicalBinaryExpression) {
LogicalBinaryExpression logical = (LogicalBinaryExpression) expression;
if (logical.getOperator() == LogicalBinaryExpression.Operator.OR) {
return isSimpleKeyEquality(logical.getLeft()) && isSimpleKeyEquality(logical.getRight());
}
}
if (expression instanceof FunctionCall) {
FunctionCall func = (FunctionCall) expression;
String funcName = func.getName().toString();
return funcName.equalsIgnoreCase("contains") || funcName.equalsIgnoreCase("presto.default.contains");
}
return false;
}

private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type boundType, StackableAstVisitorContext<Context> context, Window window)
{
if (!window.getOrderBy().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,30 +168,27 @@ void testNoORWarning()
@Test
public void testMapFilterWarnings()
{
assertHasWarning(
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(x)"),
PERFORMANCE_WARNING,
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)"));

assertHasWarning(
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"),
PERFORMANCE_WARNING,
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");

assertHasWarning(
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"),
PERFORMANCE_WARNING,
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");

assertHasWarning(
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
PERFORMANCE_WARNING,
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"));

assertHasWarning(
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
PERFORMANCE_WARNING,
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"));

assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k > 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"));

assertNoWarning(analyzeWithWarnings("SELECT transform_values(user_features, (k, v) -> v * 2) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)"));

assertNoWarning(analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public void testMapWithDoubleKeysProducesWarnings()
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));

query = "select transform_keys(map(ARRAY [25.5E0, 26.5E0, 27.5E0], ARRAY [25.5E0, 26.5E0, 27.5E0]), (k, v) -> k + v)";
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode(), PERFORMANCE_WARNING.toWarningCode()));
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));

query = "SELECT histogram(RETAILPRICE) FROM tpch.tiny.part";
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
Expand Down
Loading