diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java index 09795ae62c2ea..da50e15df97fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java @@ -367,7 +367,7 @@ public void addInput(Page page) if (aggregationBuilder == null) { // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. - if (step.isOutputPartial() || !spillEnabled || hasOrderBy() || hasDistinct()) { + if (step.isOutputPartial() || !spillEnabled) { aggregationBuilder = new InMemoryHashAggregationBuilder( accumulatorFactories, step, @@ -411,16 +411,6 @@ public void addInput(Page page) aggregationBuilder.updateMemory(); } - private boolean hasOrderBy() - { - return accumulatorFactories.stream().anyMatch(AccumulatorFactory::hasOrderBy); - } - - private boolean hasDistinct() - { - return accumulatorFactories.stream().anyMatch(AccumulatorFactory::hasDistinct); - } - @Override public ListenableFuture startMemoryRevoke() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AccumulatorFactoryBinder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AccumulatorFactoryBinder.java index d8e5127c453d0..963d7428e9ae8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AccumulatorFactoryBinder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AccumulatorFactoryBinder.java @@ -34,5 +34,6 @@ AccumulatorFactory bind( boolean distinct, JoinCompiler joinCompiler, List lambdaProviders, + boolean spillEnabled, Session session); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/FinalOnlyGroupedAccumulator.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/FinalOnlyGroupedAccumulator.java new file mode 100644 index 0000000000000..23a1759353dcd --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/FinalOnlyGroupedAccumulator.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.GroupByIdBlock; + +/** + * {@link FinalOnlyGroupedAccumulator} is an accumulator that does not support partial aggregation + * This is for spilling purposes so any underlying accumulator must support spilling + */ +public abstract class FinalOnlyGroupedAccumulator + implements GroupedAccumulator +{ + @Override + public final Type getIntermediateType() + { + throw new UnsupportedOperationException(); + } + + @Override + public final void addIntermediate(GroupByIdBlock groupIdsBlock, Block block) + { + throw new UnsupportedOperationException(); + } + + @Override + public final void evaluateIntermediate(int groupId, BlockBuilder output) + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java index e7d4a31a5b37a..048f8a4ffdead 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java @@ -14,10 +14,20 @@ package com.facebook.presto.operator.aggregation; import com.facebook.presto.Session; +import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.ArrayBlock; +import com.facebook.presto.common.block.ArrayBlockBuilder; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.ColumnarArray; +import com.facebook.presto.common.block.ColumnarRow; +import com.facebook.presto.common.block.LongArrayBlock; +import com.facebook.presto.common.block.RowBlock; +import com.facebook.presto.common.block.RowBlockBuilder; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.GroupByIdBlock; import com.facebook.presto.operator.MarkDistinctHash; @@ -29,6 +39,7 @@ import com.facebook.presto.sql.gen.JoinCompiler; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; +import org.openjdk.jol.info.ClassLayout; import javax.annotation.Nullable; @@ -40,6 +51,8 @@ import java.util.Optional; import java.util.stream.Collectors; +import static com.facebook.presto.common.block.ColumnarArray.toColumnarArray; +import static com.facebook.presto.common.block.ColumnarRow.toColumnarRow; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.google.common.base.Preconditions.checkArgument; @@ -67,6 +80,7 @@ public class GenericAccumulatorFactory @Nullable private final Session session; private final boolean distinct; + private final boolean spillEnabled; private final PagesIndex.Factory pagesIndexFactory; public GenericAccumulatorFactory( @@ -82,7 +96,8 @@ public GenericAccumulatorFactory( PagesIndex.Factory pagesIndexFactory, JoinCompiler joinCompiler, Session session, - boolean distinct) + boolean distinct, + boolean spillEnabled) { this.stateDescriptors = requireNonNull(stateDescriptors, "stateDescriptors is null"); this.accumulatorConstructor = requireNonNull(accumulatorConstructor, "accumulatorConstructor is null"); @@ -100,6 +115,7 @@ public GenericAccumulatorFactory( this.joinCompiler = joinCompiler; this.session = session; this.distinct = distinct; + this.spillEnabled = spillEnabled; } @Override @@ -113,7 +129,7 @@ public Accumulator createAccumulator() { Accumulator accumulator; - if (distinct) { + if (hasDistinct()) { // channel 0 will contain the distinct mask accumulator = instantiateAccumulator( inputChannels.stream() @@ -151,10 +167,47 @@ public Accumulator createIntermediateAccumulator() @Override public GroupedAccumulator createGroupedAccumulator() + { + GroupedAccumulator accumulator = createGenericGroupedAccumulator(); + if (!spillEnabled || (!hasDistinct() && !hasOrderBy())) { + return accumulator; + } + + checkState(accumulator instanceof FinalOnlyGroupedAccumulator); + return new SpillableFinalOnlyGroupedAccumulator(sourceTypes, (FinalOnlyGroupedAccumulator) accumulator); + } + + @Override + public GroupedAccumulator createGroupedIntermediateAccumulator() + { + if (!hasOrderBy() && !hasDistinct()) { + try { + return groupedAccumulatorConstructor.newInstance(stateDescriptors, ImmutableList.of(), Optional.empty(), lambdaProviders); + } + catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + return createGroupedAccumulator(); + } + + @Override + public boolean hasOrderBy() + { + return !orderByChannels.isEmpty(); + } + + @Override + public boolean hasDistinct() + { + return distinct; + } + + private GroupedAccumulator createGenericGroupedAccumulator() { GroupedAccumulator accumulator; - if (distinct) { + if (hasDistinct()) { // channel 0 will contain the distinct mask accumulator = instantiateGroupedAccumulator( inputChannels.stream() @@ -180,29 +233,6 @@ public GroupedAccumulator createGroupedAccumulator() return new OrderingGroupedAccumulator(accumulator, sourceTypes, orderByChannels, orderings, pagesIndexFactory); } - @Override - public GroupedAccumulator createGroupedIntermediateAccumulator() - { - try { - return groupedAccumulatorConstructor.newInstance(stateDescriptors, ImmutableList.of(), Optional.empty(), lambdaProviders); - } - catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e); - } - } - - @Override - public boolean hasOrderBy() - { - return !orderByChannels.isEmpty(); - } - - @Override - public boolean hasDistinct() - { - return distinct; - } - private Accumulator instantiateAccumulator(List inputs, Optional mask) { try { @@ -322,7 +352,7 @@ private static Page filter(Page page, Block mask) } private static class DistinctingGroupedAccumulator - implements GroupedAccumulator + extends FinalOnlyGroupedAccumulator { private final GroupedAccumulator accumulator; private final MarkDistinctHash hash; @@ -365,12 +395,6 @@ public Type getFinalType() return accumulator.getFinalType(); } - @Override - public Type getIntermediateType() - { - throw new UnsupportedOperationException(); - } - @Override public void addInput(GroupByIdBlock groupIdsBlock, Page page) { @@ -399,18 +423,6 @@ public void addInput(GroupByIdBlock groupIdsBlock, Page page) accumulator.addInput(groupIds, new Page(filtered.getPositionCount(), columns)); } - @Override - public void addIntermediate(GroupByIdBlock groupIdsBlock, Block block) - { - throw new UnsupportedOperationException(); - } - - @Override - public void evaluateIntermediate(int groupId, BlockBuilder output) - { - throw new UnsupportedOperationException(); - } - @Override public void evaluateFinal(int groupId, BlockBuilder output) { @@ -497,7 +509,7 @@ public void evaluateFinal(BlockBuilder blockBuilder) } private static class OrderingGroupedAccumulator - implements GroupedAccumulator + extends FinalOnlyGroupedAccumulator { private final GroupedAccumulator accumulator; private final List orderByChannels; @@ -535,12 +547,6 @@ public Type getFinalType() return accumulator.getFinalType(); } - @Override - public Type getIntermediateType() - { - throw new UnsupportedOperationException(); - } - @Override public void addInput(GroupByIdBlock groupIdsBlock, Page page) { @@ -554,18 +560,6 @@ public void addInput(GroupByIdBlock groupIdsBlock, Page page) pagesIndex.addPage(new Page(blocks)); } - @Override - public void addIntermediate(GroupByIdBlock groupIdsBlock, Block block) - { - throw new UnsupportedOperationException(); - } - - @Override - public void evaluateIntermediate(int groupId, BlockBuilder output) - { - throw new UnsupportedOperationException(); - } - @Override public void evaluateFinal(int groupId, BlockBuilder output) { @@ -586,4 +580,179 @@ public void prepareFinal() }); } } + + /** + * {@link SpillableFinalOnlyGroupedAccumulator} enables spilling for {@link FinalOnlyGroupedAccumulator} + */ + private static class SpillableFinalOnlyGroupedAccumulator + implements GroupedAccumulator + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SpillableFinalOnlyGroupedAccumulator.class).instanceSize(); + + private final FinalOnlyGroupedAccumulator delegate; + private final List spillingTypes; + + private ObjectBigArray rawInputs = new ObjectBigArray<>(); + private ObjectBigArray blockBuilders; + private long rawInputsLength; + + public SpillableFinalOnlyGroupedAccumulator(List types, FinalOnlyGroupedAccumulator delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.spillingTypes = requireNonNull(types, "types is null"); + } + + @Override + public long getEstimatedSize() + { + return INSTANCE_SIZE + + delegate.getEstimatedSize() + + (rawInputs == null ? 0 : rawInputs.sizeOf()) + + (blockBuilders == null ? 0 : blockBuilders.sizeOf()); + } + + @Override + public Type getFinalType() + { + return delegate.getFinalType(); + } + + @Override + public Type getIntermediateType() + { + return new ArrayType(RowType.anonymous(spillingTypes)); + } + + @Override + public void addInput(GroupByIdBlock groupIdsBlock, Page page) + { + checkState(rawInputs != null && blockBuilders == null); + rawInputs.ensureCapacity(rawInputsLength); + rawInputs.set(rawInputsLength, new GroupIdPage(groupIdsBlock, page)); + // TODO(sakshams) deduplicate inputs for DISTINCT accumulator case by doing page compaction + rawInputsLength++; + } + + @Override + public void addIntermediate(GroupByIdBlock groupIdsBlock, Block block) + { + checkState(rawInputs != null && blockBuilders == null); + checkState(block instanceof ArrayBlock); + ArrayBlock arrayBlock = (ArrayBlock) block; + + // expand array block back into page + ColumnarArray columnarArray = toColumnarArray(block); // flattens the squashed arrays; so there is no need to flatten block again. + ColumnarRow columnarRow = toColumnarRow(columnarArray.getElementsBlock()); // contains the flattened array + int newPositionCount = columnarRow.getPositionCount(); // number of positions in expanded array (since columnarRow is already flattened) + long[] newGroupIds = new long[newPositionCount]; + boolean[] nulls = new boolean[newPositionCount]; + int currentRowBlockIndex = 0; + for (int groupIdPosition = 0; groupIdPosition < groupIdsBlock.getPositionCount(); groupIdPosition++) { + for (int unused = 0; unused < arrayBlock.getBlock(groupIdPosition).getPositionCount(); unused++) { + // unused because we are expanding all the squashed values for the same group id + newGroupIds[currentRowBlockIndex] = groupIdsBlock.getGroupId(groupIdPosition); + nulls[currentRowBlockIndex] = groupIdsBlock.isNull(groupIdPosition); + currentRowBlockIndex++; + } + } + + Block[] blocks = new Block[spillingTypes.size()]; + for (int channel = 0; channel < spillingTypes.size(); channel++) { + blocks[channel] = columnarRow.getField(channel); + } + Page page = new Page(blocks); + GroupByIdBlock squashedGroupIds = new GroupByIdBlock(groupIdsBlock.getGroupCount(), new LongArrayBlock(newPositionCount, Optional.of(nulls), newGroupIds)); + + rawInputs.ensureCapacity(rawInputsLength); + rawInputs.set(rawInputsLength, new GroupIdPage(squashedGroupIds, page)); + rawInputsLength++; + } + + @Override + public void evaluateIntermediate(int groupId, BlockBuilder output) + { + checkState(output instanceof ArrayBlockBuilder); + if (blockBuilders == null) { + checkState(rawInputs != null); + + blockBuilders = new ObjectBigArray<>(); + for (int i = 0; i < rawInputsLength; i++) { + GroupIdPage groupIdPage = rawInputs.get(i); + Page page = groupIdPage.getPage(); + GroupByIdBlock groupIdsBlock = groupIdPage.getGroupByIdBlock(); + for (int position = 0; position < page.getPositionCount(); position++) { + long currentGroupId = groupIdsBlock.getGroupId(position); + blockBuilders.ensureCapacity(currentGroupId); + RowBlockBuilder rowBlockBuilder = blockBuilders.get(currentGroupId); + if (rowBlockBuilder == null) { + rowBlockBuilder = new RowBlockBuilder(spillingTypes, null, (int) groupIdsBlock.getGroupCount()); + } + + BlockBuilder currentOutput = rowBlockBuilder.beginBlockEntry(); + for (int channel = 0; channel < spillingTypes.size(); channel++) { + spillingTypes.get(channel).appendTo(page.getBlock(channel), position, currentOutput); + } + rowBlockBuilder.closeEntry(); + + blockBuilders.set(currentGroupId, rowBlockBuilder); + } + } + rawInputs = null; + rawInputsLength = 0; + } + + BlockBuilder singleArrayBlockWriter = output.beginBlockEntry(); + checkState(rawInputs == null && blockBuilders != null); + + // We need to squash the entire page into one array block since we can't spill multiple values for a single group ID during evaluateIntermediate. + RowBlock rowBlock = (RowBlock) blockBuilders.get(groupId).build(); + for (int i = 0; i < rowBlock.getPositionCount(); i++) { + singleArrayBlockWriter.appendStructure(rowBlock.getBlock(i)); + } + output.closeEntry(); + } + + @Override + public void evaluateFinal(int groupId, BlockBuilder output) + { + checkState(rawInputs == null && blockBuilders == null); + delegate.evaluateFinal(groupId, output); + } + + @Override + public void prepareFinal() + { + checkState(rawInputs != null && blockBuilders == null); + for (int i = 0; i < rawInputsLength; i++) { + GroupIdPage groupIdPage = rawInputs.get(i); + delegate.addInput(groupIdPage.getGroupByIdBlock(), groupIdPage.getPage()); + } + + rawInputs = null; + rawInputsLength = 0; + delegate.prepareFinal(); + } + + private static class GroupIdPage + { + private final GroupByIdBlock groupByIdBlock; + private final Page page; + + public GroupIdPage(GroupByIdBlock groupByIdBlock, Page page) + { + this.page = requireNonNull(page, "page is null"); + this.groupByIdBlock = requireNonNull(groupByIdBlock, "groupByIdBlock is null"); + } + + public Page getPage() + { + return page; + } + + public GroupByIdBlock getGroupByIdBlock() + { + return groupByIdBlock; + } + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactoryBinder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactoryBinder.java index d475f1e2dd8ef..826e99e93b8c5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactoryBinder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactoryBinder.java @@ -70,6 +70,7 @@ public AccumulatorFactory bind( boolean distinct, JoinCompiler joinCompiler, List lambdaProviders, + boolean spillEnabled, Session session) { return new GenericAccumulatorFactory( @@ -85,7 +86,8 @@ public AccumulatorFactory bind( pagesIndexFactory, joinCompiler, session, - distinct); + distinct, + spillEnabled); } @VisibleForTesting diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/InternalAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/InternalAggregationFunction.java index 13f87998f2f6d..da2aee78c2355 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/InternalAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/InternalAggregationFunction.java @@ -139,6 +139,7 @@ public AccumulatorFactory bind(List inputChannels, Optional ma false, null, ImmutableList.of(), + false, null); } @@ -152,9 +153,10 @@ public AccumulatorFactory bind( boolean distinct, JoinCompiler joinCompiler, List lambdaProviders, + boolean spillEnabled, Session session) { - return factory.bind(inputChannels, maskChannel, sourceTypes, orderByChannels, orderings, pagesIndexFactory, distinct, joinCompiler, lambdaProviders, session); + return factory.bind(inputChannels, maskChannel, sourceTypes, orderByChannels, orderings, pagesIndexFactory, distinct, joinCompiler, lambdaProviders, spillEnabled, session); } @VisibleForTesting diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/LazyAccumulatorFactoryBinder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/LazyAccumulatorFactoryBinder.java index c55b4b9466d64..018d2f4bde7d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/LazyAccumulatorFactoryBinder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/LazyAccumulatorFactoryBinder.java @@ -53,8 +53,20 @@ public AccumulatorFactory bind( boolean distinct, JoinCompiler joinCompiler, List lambdaProviders, + boolean spillEnabled, Session session) { - return binder.get().bind(argumentChannels, maskChannel, sourceTypes, orderByChannels, orderings, pagesIndexFactory, distinct, joinCompiler, lambdaProviders, session); + return binder.get().bind( + argumentChannels, + maskChannel, + sourceTypes, + orderByChannels, + orderings, + pagesIndexFactory, + distinct, + joinCompiler, + lambdaProviders, + spillEnabled, + session); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index a7371925b838d..09790b821ebe4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -66,6 +66,7 @@ public class InMemoryHashAggregationBuilder private final boolean useSystemMemory; private boolean full; + private boolean hasBuiltFinalResult; public InMemoryHashAggregationBuilder( List accumulatorFactories, @@ -253,6 +254,7 @@ public long getGroupCount() @Override public WorkProcessor buildResult() { + hasBuiltFinalResult = true; for (Aggregator aggregator : aggregators) { aggregator.prepareFinal(); } @@ -273,6 +275,11 @@ public List buildIntermediateTypes() return types; } + public boolean hasBuiltFinalResult() + { + return hasBuiltFinalResult; + } + @VisibleForTesting public int getCapacity() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java index 0cec97eb5e991..ba8624bcb7c78 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -155,6 +155,11 @@ private boolean hasPreviousSpillCompletedSuccessfully() public ListenableFuture startMemoryRevoke() { checkState(spillInProgress.isDone()); + if (hashAggregationBuilder.hasBuiltFinalResult()) { + // If the hashAggregationBuilder has already completed, decline memory revoking. At this point, buildResult has already been called + // on InMemoryHashAggregationBuilder and it is no longer accepting any input so no point in spilling. + return spillInProgress; + } spillToDisk(); return spillInProgress; } @@ -162,6 +167,10 @@ public ListenableFuture startMemoryRevoke() @Override public void finishMemoryRevoke() { + if (hashAggregationBuilder.hasBuiltFinalResult()) { + // Do not update memory if we never spilt during startMemoryRevoke if hashAggregationBuilder has already built it's final result + return; + } updateMemory(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index a3055ca729446..803b7cff58666 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -2704,7 +2704,8 @@ private List getVariableTypes(List variables) private AccumulatorFactory buildAccumulatorFactory( PhysicalOperation source, - Aggregation aggregation) + Aggregation aggregation, + boolean spillEnabled) { FunctionManager functionManager = metadata.getFunctionManager(); InternalAggregationFunction internalAggregationFunction = functionManager.getAggregateFunctionImplementation(aggregation.getFunctionHandle()); @@ -2752,6 +2753,7 @@ private AccumulatorFactory buildAccumulatorFactory( aggregation.isDistinct(), joinCompiler, lambdaProviders, + spillEnabled, session); } @@ -2785,7 +2787,7 @@ private AggregationOperatorFactory createAggregationOperatorFactory( for (Map.Entry entry : aggregations.entrySet()) { VariableReferenceExpression variable = entry.getKey(); Aggregation aggregation = entry.getValue(); - accumulatorFactories.add(buildAccumulatorFactory(source, aggregation)); + accumulatorFactories.add(buildAccumulatorFactory(source, aggregation, false)); outputMappings.put(variable, outputChannel); // one aggregation per channel outputChannel++; } @@ -2848,7 +2850,7 @@ private OperatorFactory createHashAggregationOperatorFactory( VariableReferenceExpression variable = entry.getKey(); Aggregation aggregation = entry.getValue(); - accumulatorFactories.add(buildAccumulatorFactory(source, aggregation)); + accumulatorFactories.add(buildAccumulatorFactory(source, aggregation, !isStreamable && spillEnabled)); aggregationOutputVariables.add(variable); } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestDistributedSpilledQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestDistributedSpilledQueries.java index 74f77a77a50a0..e4f7af64984c7 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestDistributedSpilledQueries.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestDistributedSpilledQueries.java @@ -40,6 +40,7 @@ public static DistributedQueryRunner createQueryRunner() .setSystemProperty(SystemSessionProperties.TASK_CONCURRENCY, "2") .setSystemProperty(SystemSessionProperties.SPILL_ENABLED, "true") .setSystemProperty(SystemSessionProperties.AGGREGATION_OPERATOR_UNSPILL_MEMORY_LIMIT, "128kB") + .setSystemProperty(SystemSessionProperties.USE_MARK_DISTINCT, "false") .build(); ImmutableMap extraProperties = ImmutableMap.builder() diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledAggregations.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledAggregations.java index b0d59791ccd85..f2a5fb762c46b 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledAggregations.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestSpilledAggregations.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.tests; +import org.testng.annotations.Test; + public class TestSpilledAggregations extends AbstractTestAggregations { @@ -20,4 +22,56 @@ public TestSpilledAggregations() { super(TestDistributedSpilledQueries::createQueryRunner); } + + @Test + public void OrderBySpillingBasic() + { + assertQuery("SELECT orderpriority, custkey, array_agg(orderstatus ORDER BY orderstatus) FROM orders GROUP BY orderpriority, custkey ORDER BY 1, 2"); + } + + @Test + public void OrderBySpillingGroupingSets() + { + assertQuery( + "SELECT orderpriority, custkey, array_agg(orderstatus ORDER BY orderstatus) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) " + + "GROUP BY GROUPING SETS ((), (orderpriority), (orderpriority, custkey))", + "SELECT NULL, NULL, array_agg(orderstatus ORDER BY orderstatus) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) UNION ALL " + + "SELECT orderpriority, NULL, array_agg(orderstatus ORDER BY orderstatus) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) GROUP BY orderpriority UNION ALL " + + "SELECT orderpriority, custkey, array_agg(orderstatus ORDER BY orderstatus) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) GROUP BY orderpriority, custkey"); + } + + @Test + public void DistinctSpillingBasic() + { + // the sum() is necessary so that the aggregation isn't optimized into multiple aggregation nodes + assertQuery("SELECT custkey, sum(custkey), count(DISTINCT orderpriority) FILTER(WHERE orderkey > 5) FROM orders GROUP BY custkey ORDER BY 1"); + } + + @Test + public void DistinctAndOrderBySpillingBasic() + { + assertQuery("SELECT custkey, orderpriority, sum(custkey), array_agg(DISTINCT orderpriority ORDER BY orderpriority) FROM orders GROUP BY custkey, orderpriority ORDER BY 1, 2"); + } + + @Test + public void DistinctSpillingCount() + { + assertQuery("SELECT orderpriority, custkey, sum(custkey), count(DISTINCT totalprice) FROM orders GROUP BY orderpriority, custkey ORDER BY 1, 2"); + } + + @Test + public void DistinctSpillingGroupingSets() + { + assertQuery( + "SELECT custkey, count(DISTINCT orderpriority) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) " + + "GROUP BY GROUPING SETS ((), (custkey))", + "SELECT NULL, count(DISTINCT orderpriority) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) UNION ALL " + + "SELECT custkey, count(DISTINCT orderpriority) FROM orders WHERE orderkey IN (1, 2, 3, 4, 5) GROUP BY custkey"); + } + + @Test + public void TestNonGroupedOrderBySpill() + { + assertQuery("SELECT array_agg(orderstatus ORDER BY orderstatus) FROM orders"); + } }