diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java index 1c6b974853eda..3db7dc66cc6e4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java @@ -36,12 +36,16 @@ import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter; +import static com.facebook.presto.expressions.LogicalRowExpressions.or; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; @@ -123,11 +127,13 @@ private static class Context { private final Map partialResultToMask; private final Map partialOutputMapping; + private final List newAggregationOutput; public Context() { partialResultToMask = new HashMap<>(); partialOutputMapping = new HashMap<>(); + newAggregationOutput = new LinkedList<>(); } public boolean isEmpty() @@ -139,6 +145,7 @@ public void clear() { partialResultToMask.clear(); partialOutputMapping.clear(); + newAggregationOutput.clear(); } public Map getPartialOutputMapping() @@ -150,6 +157,11 @@ public Map getPartialR { return partialResultToMask; } + + public List getNewAggregationOutput() + { + return newAggregationOutput; + } } private static class Rewriter @@ -218,17 +230,60 @@ else if (node.getStep().equals(FINAL)) { private AggregationNode 
createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext context) { checkState(context.get().isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node"); + Map aggregationsWithoutMaskToOutput = node.getAggregations().entrySet().stream() .filter(x -> !x.getValue().getMask().isPresent()) - .collect(toImmutableMap(x -> x.getValue(), x -> x.getKey(), (a, b) -> a)); + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (a, b) -> a)); Map aggregationsToMergeOutput = node.getAggregations().entrySet().stream() .filter(x -> x.getValue().getMask().isPresent() && aggregationsWithoutMaskToOutput.containsKey(removeFilterAndMask(x.getValue()))) - .collect(toImmutableMap(x -> x.getValue(), x -> x.getKey())); + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + ImmutableMap.Builder partialAggregationToOutputBuilder = ImmutableMap.builder(); + partialAggregationToOutputBuilder.putAll(aggregationsToMergeOutput.keySet().stream().collect(toImmutableMap(Function.identity(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x))))); + + List> candidateAggregationsWithMaskNotMatched = node.getAggregations().entrySet().stream().map(Map.Entry::getValue) + .filter(x -> x.getMask().isPresent() && !aggregationsToMergeOutput.containsKey(x)) + .collect(Collectors.groupingBy(AggregationNodeUtils::removeFilterAndMask)).values() + .stream().filter(x -> x.size() > 1).collect(toImmutableList()); + + Map aggregationsWithMaskToMerge = node.getAggregations().entrySet().stream() + .filter(x -> aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue()))) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + ImmutableMap.Builder newMaskAssignmentsBuilder = ImmutableMap.builder(); + ImmutableMap.Builder aggregationsAddedBuilder = ImmutableMap.builder(); + List 
newAggregationAdded = candidateAggregationsWithMaskNotMatched.stream() + .map(aggregations -> + { + List maskVariables = aggregations.stream().map(x -> x.getMask().get()).collect(toImmutableList()); + RowExpression orMaskVariables = or(maskVariables); + VariableReferenceExpression newMaskVariable = variableAllocator.newVariable(orMaskVariables); + newMaskAssignmentsBuilder.put(newMaskVariable, orMaskVariables); + AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation( + aggregations.get(0).getCall(), + Optional.empty(), + aggregations.get(0).getOrderBy(), + aggregations.get(0).isDistinct(), + Optional.of(newMaskVariable)); + VariableReferenceExpression newAggregationVariable = variableAllocator.newVariable(newAggregation.getCall()); + aggregationsAddedBuilder.put(newAggregationVariable, newAggregation); + aggregations.forEach(x -> partialAggregationToOutputBuilder.put(x, newAggregationVariable)); + return newAggregation; + }) + .collect(toImmutableList()); + Map newMaskAssignments = newMaskAssignmentsBuilder.build(); + Map aggregationsAdded = aggregationsAddedBuilder.build(); + Map partialAggregationToOutput = partialAggregationToOutputBuilder.build(); + + Map aggregationsToMergeOutputCombined = + node.getAggregations().entrySet().stream() + .filter(x -> x.getValue().getMask().isPresent() && aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue()))) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); - context.get().getPartialResultToMask().putAll(aggregationsToMergeOutput.entrySet().stream() - .collect(toImmutableMap(x -> x.getValue(), x -> x.getKey().getMask().get()))); - context.get().getPartialOutputMapping().putAll(aggregationsToMergeOutput.entrySet().stream() - .collect(toImmutableMap(x -> x.getValue(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x.getKey()))))); + 
context.get().getNewAggregationOutput().addAll(aggregationsAdded.keySet()); + context.get().getPartialResultToMask().putAll(aggregationsWithMaskToMerge.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getValue, x -> x.getKey().getMask().get()))); + context.get().getPartialOutputMapping().putAll(aggregationsWithMaskToMerge.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getValue, x -> partialAggregationToOutput.get(x.getKey())))); Set maskVariables = new HashSet<>(context.get().getPartialResultToMask().values()); if (maskVariables.isEmpty()) { @@ -242,14 +297,21 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN AggregationNode.GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode.GroupingSetDescriptor( groupingVariables.build(), groupingSetDescriptor.getGroupingSetCount(), groupingSetDescriptor.getGlobalGroupingSets()); - Set partialResultToMerge = new HashSet<>(aggregationsToMergeOutput.values()); - Map newAggregations = node.getAggregations().entrySet().stream() + Set partialResultToMerge = new HashSet<>(aggregationsToMergeOutputCombined.values()); + Map aggregationsRemained = node.getAggregations().entrySet().stream() .filter(x -> !partialResultToMerge.contains(x.getKey())).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + Map newAggregations = ImmutableMap.builder() + .putAll(aggregationsRemained).putAll(aggregationsAdded).build(); + + PlanNode newChild = rewrittenSource; + if (!newMaskAssignments.isEmpty()) { + newChild = addProjections(newChild, planNodeIdAllocator, newMaskAssignments); + } return new AggregationNode( node.getSourceLocation(), node.getId(), - rewrittenSource, + newChild, newAggregations, partialGroupingSetDescriptor, node.getPreGroupedVariables(), @@ -265,7 +327,7 @@ private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNod return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource)); } List intermediateVariables = 
node.getAggregations().values().stream() - .map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(Collectors.toList()); + .map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(toImmutableList()); checkState(intermediateVariables.containsAll(context.get().partialResultToMask.keySet())); ImmutableList.Builder projectionsFromPartialAgg = ImmutableList.builder(); @@ -331,6 +393,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); assignments.putAll(excludeMergedAssignments); assignments.putAll(identityAssignments(context.get().getPartialResultToMask().values())); + assignments.putAll(identityAssignments(context.get().getNewAggregationOutput())); return new ProjectNode( node.getSourceLocation(), node.getId(), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java index 7d61a00af47ef..568e3c6ee38a2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java @@ -87,6 +87,37 @@ public void testOptimizationApplied() false); } + @Test + public void testOptimizationAppliedAllHasMask() + { + assertPlan("SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey", + enableOptimization(), + anyTree( + aggregation( + singleGroupingSet("partkey"), + ImmutableMap.of(Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")), + Optional.of("maskFinalSum2"), functionCall("sum", ImmutableList.of("maskPartialSum2"))), + ImmutableMap.of(), + Optional.empty(), + 
AggregationNode.Step.FINAL, + project( + ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"), + "maskPartialSum2", expression("IF(expr2, partialSum, null)")), + anyTree( + aggregation( + singleGroupingSet("partkey", "expr", "expr2"), + ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))), + ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project( + ImmutableMap.of("expr_or", expression("expr or expr2")), + project( + ImmutableMap.of("expr", expression("orderkey > 0"), "expr2", expression("orderkey >10")), + tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))), + false); + } + @Test public void testOptimizationDisabled() { @@ -188,6 +219,57 @@ public void testAggregationsMultipleLevel() false); } + @Test + public void testAggregationsMultipleLevelAllAggWithMask() + { + assertPlan("select partkey, avg(sum) filter (where suppkey > 10), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where orderkey > 10) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey", + enableOptimization(), + anyTree( + aggregation( + singleGroupingSet("partkey"), + ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg_g10")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")), + Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.FINAL, + project( + ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)"), + "maskPartialAvg_g10", expression("IF(expr_2_g10, partialAvg, null)")), + anyTree( + aggregation( + singleGroupingSet("partkey", "expr_2", "expr_2_g10"), + 
ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum_g10")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))), + ImmutableMap.of(new Symbol("partialAvg"), new Symbol("expr_2_or")), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project( + ImmutableMap.of("expr_2_or", expression("expr_2 or expr_2_g10")), + project( + ImmutableMap.of("expr_2", expression("suppkey > 0"), "expr_2_g10", expression("suppkey > 10")), + aggregation( + singleGroupingSet("partkey", "suppkey"), + ImmutableMap.of(Optional.of("finalSum_g10"), functionCall("sum", ImmutableList.of("maskPartialSum_g10")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.FINAL, + project( + ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"), + "maskPartialSum_g10", expression("IF(expr_g10, partialSum, null)")), + anyTree( + aggregation( + singleGroupingSet("partkey", "suppkey", "expr", "expr_g10"), + ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))), + ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project( + ImmutableMap.of("expr_or", expression("expr or expr_g10")), + project( + ImmutableMap.of("expr", expression("orderkey > 0"), "expr_g10", expression("orderkey > 10")), + tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))), + false); + } + @Test public void testGlobalOptimization() { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index ad37f43f6a0c4..6befe1311ef13 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ 
b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -6861,6 +6861,51 @@ public void testSameAggregationWithAndWithoutFilter() resultWithOptimization = computeActual(enableOptimization, sql); resultWithoutOptimization = computeActual(disableOptimization, sql); assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + + // Now cover filtered aggregations that have no unfiltered counterpart + // multiple aggregations in query + sql = "select partkey, sum(quantity) filter (where discount > 0.05), sum(quantity) filter (where discount < 0.05), sum(linenumber) filter (where discount > 0.05), sum(linenumber) filter (where discount < 0.05) from lineitem group by partkey"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + sql = "select partkey, sum(quantity) filter (where discount > 0.05), sum(quantity) filter (where discount < 0.05), sum(linenumber), sum(linenumber) filter (where discount < 0.05) from lineitem group by partkey"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + // aggregations in multiple levels + sql = "select partkey, avg(sum) filter (where tax > 0.05), avg(sum) filter (where tax < 0.05), avg(filtersum) from (select partkey, suppkey, sum(quantity) sum, sum(quantity) filter (where discount > 0.05) filtersum, max(tax) tax from lineitem where partkey=1598 group by partkey, suppkey) t group by partkey"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + sql = "select partkey, avg(sum) filter (where tax > 0.05), avg(sum) filter 
(where tax < 0.05), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where discount < 0.05) sum, sum(quantity) filter (where discount > 0.05) filtersum, max(tax) tax from lineitem where partkey=1598 group by partkey, suppkey) t group by partkey"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + // global aggregation + sql = "select sum(quantity) filter (where discount > 0.05), sum(quantity) filter (where discount < 0.05) from lineitem"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + // order by + sql = "select partkey, array_agg(suppkey order by suppkey) filter (where discount < 0.05), array_agg(suppkey order by suppkey) filter (where discount > 0.05) from lineitem group by partkey"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + // grouping sets + sql = "SELECT partkey, suppkey, sum(quantity) filter (where discount < 0.05), sum(quantity) filter (where discount > 0.05) from lineitem group by grouping sets((), (partkey), (partkey, suppkey))"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + // aggregation over union + sql = "SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from (select quantity, orderkey, partkey from lineitem union all select totalprice as quantity, orderkey, custkey as partkey from orders) group by partkey"; + 
resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); + // aggregation over join + sql = "select custkey, sum(quantity) filter (where tax > 0.05), sum(quantity) filter (where tax < 0.05) from lineitem l join orders o on l.orderkey=o.orderkey group by custkey"; + resultWithOptimization = computeActual(enableOptimization, sql); + resultWithoutOptimization = computeActual(disableOptimization, sql); + assertEqualsIgnoreOrder(resultWithOptimization, resultWithoutOptimization); } @Test