diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index e2739aca9a69b..48917e5fcb61d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -1132,24 +1132,15 @@ else if (frame.getType() == GROUPS) { List arguments = functionMetadata.getArgumentTypes(); String functionName = functionMetadata.getName().toString(); - if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase())) { - if (arguments.size() > 1) { - arguments.stream() - .skip(1) - .filter(arg -> { - String base = arg.getBase(); - return "function".equals(base) || "lambda".equals(base); - }) - .findFirst() - .ifPresent(arg -> { - String warningMessage = createWarningMessage(node, - String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName)); - warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage)); - }); - } - else if (arguments.size() == 1) { - String base = arguments.get(0).getBase(); - if ("function".equals(base) || "lambda".equals(base)) { + if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase()) && + "map_filter".equalsIgnoreCase(functionMetadata.getName().getObjectName()) && + arguments.size() > 1 && node.getArguments().size() >= 2) { + Expression mapArg = node.getArguments().get(0); + Expression lambdaArg = node.getArguments().get(1); + + if (containsFeatures(mapArg) && lambdaArg instanceof LambdaExpression) { + LambdaExpression lambda = (LambdaExpression) lambdaArg; + if (lambda.getArguments().size() == 2 && isKeyOnlyMembershipFilter(lambda)) { String warningMessage = createWarningMessage(node, String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName)); warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage)); @@ -1216,6 +1207,100 @@ private String createWarningMessage(Node node, String message) } } + private boolean isKeyOnlyMembershipFilter(LambdaExpression lambda) + { + String valueArgName = lambda.getArguments().get(1).getName().getValue(); + Expression body = lambda.getBody(); + + if (expressionReferencesName(body, valueArgName)) { + return false; + } + + return isSimpleKeyEquality(body); + } + + private boolean expressionReferencesName(Expression expression, String name) + { + if (expression == null) { + return false; + } + if (expression instanceof Identifier) { + return ((Identifier) expression).getValue().equalsIgnoreCase(name); + } + if (expression instanceof ComparisonExpression) { + ComparisonExpression comp = (ComparisonExpression) expression; + return expressionReferencesName(comp.getLeft(), name) || expressionReferencesName(comp.getRight(), name); + } + if (expression instanceof LogicalBinaryExpression) { + LogicalBinaryExpression logical = (LogicalBinaryExpression) expression; + return expressionReferencesName(logical.getLeft(), name) || expressionReferencesName(logical.getRight(), name); + } + if (expression instanceof InPredicate) { + InPredicate inPred = (InPredicate) expression; + return expressionReferencesName(inPred.getValue(), name) || expressionReferencesName(inPred.getValueList(), name); + } + if (expression instanceof InListExpression) { + InListExpression inList = (InListExpression) expression; + for (Expression value : inList.getValues()) { + if (expressionReferencesName(value, name)) { + return true; + } + } + } + if (expression instanceof ArithmeticBinaryExpression) { + ArithmeticBinaryExpression arith = (ArithmeticBinaryExpression) expression; + return expressionReferencesName(arith.getLeft(), name) || expressionReferencesName(arith.getRight(), name); + } + if (expression instanceof FunctionCall) { + FunctionCall func = (FunctionCall) expression; + for (Expression arg : func.getArguments()) { + if (expressionReferencesName(arg, name)) { + return true; + } + } + } + // Literals don't reference any names + return false; + } + + private boolean containsFeatures(Expression expression) + { + if (expression instanceof Identifier) { + return ((Identifier) expression).getValue().toLowerCase().contains("features"); + } + if (expression instanceof SymbolReference) { + return ((SymbolReference) expression).getName().toLowerCase().contains("features"); + } + if (expression instanceof DereferenceExpression) { + DereferenceExpression deref = (DereferenceExpression) expression; + return containsFeatures(deref.getBase()) || deref.getField().getValue().toLowerCase().contains("features"); + } + return false; + } + + private boolean isSimpleKeyEquality(Expression expression) + { + if (expression instanceof ComparisonExpression) { + ComparisonExpression comparison = (ComparisonExpression) expression; + return comparison.getOperator() == ComparisonExpression.Operator.EQUAL; + } + if (expression instanceof InPredicate) { + return true; + } + if (expression instanceof LogicalBinaryExpression) { + LogicalBinaryExpression logical = (LogicalBinaryExpression) expression; + if (logical.getOperator() == LogicalBinaryExpression.Operator.OR) { + return isSimpleKeyEquality(logical.getLeft()) && isSimpleKeyEquality(logical.getRight()); + } + } + if (expression instanceof FunctionCall) { + FunctionCall func = (FunctionCall) expression; + String funcName = func.getName().toString(); + return funcName.equalsIgnoreCase("contains") || funcName.equalsIgnoreCase("presto.default.contains"); + } + return false; + } + private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type boundType, StackableAstVisitorContext context, Window window) { if (!window.getOrderBy().isPresent()) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index a1944fc9c4765..2e3f4d2534806 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -168,30 +168,27 @@ void testNoORWarning() @Test public void testMapFilterWarnings() { - assertHasWarning( - analyzeWithWarnings("SELECT map_filter(x, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(x)"), - PERFORMANCE_WARNING, - "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset"); + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)")); assertHasWarning( - analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"), + analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"), PERFORMANCE_WARNING, "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset"); assertHasWarning( - analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"), + analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"), PERFORMANCE_WARNING, "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset"); - assertHasWarning( - analyzeWithWarnings("SELECT map_filter(x, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"), - PERFORMANCE_WARNING, - "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset"); + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)")); - assertHasWarning( - analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"), - PERFORMANCE_WARNING, - "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset"); + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k > 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT transform_values(user_features, (k, v) -> v * 2) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)")); } @Test diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestWarnings.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestWarnings.java index 7c45194f029a4..4659435920f09 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/TestWarnings.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestWarnings.java @@ -180,7 +180,7 @@ public void testMapWithDoubleKeysProducesWarnings() assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode())); query = "select transform_keys(map(ARRAY [25.5E0, 26.5E0, 27.5E0], ARRAY [25.5E0, 26.5E0, 27.5E0]), (k, v) -> k + v)"; - assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode(), PERFORMANCE_WARNING.toWarningCode())); + assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode())); query = "SELECT histogram(RETAILPRICE) FROM tpch.tiny.part"; assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));