diff --git a/core/trino-main/src/main/java/io/trino/operator/TopNOperator.java b/core/trino-main/src/main/java/io/trino/operator/TopNOperator.java index 95be5d8c0e13..aed9598f8222 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TopNOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TopNOperator.java @@ -15,9 +15,8 @@ import com.google.common.collect.ImmutableList; import io.trino.memory.context.MemoryTrackingContext; +import io.trino.operator.BasicWorkProcessorOperatorAdapter.BasicAdapterWorkProcessorOperatorFactory; import io.trino.operator.WorkProcessor.TransformationState; -import io.trino.operator.WorkProcessorOperatorAdapter.AdapterWorkProcessorOperator; -import io.trino.operator.WorkProcessorOperatorAdapter.AdapterWorkProcessorOperatorFactory; import io.trino.spi.Page; import io.trino.spi.connector.SortOrder; import io.trino.spi.type.Type; @@ -25,17 +24,16 @@ import io.trino.sql.planner.plan.PlanNodeId; import java.util.List; -import java.util.Optional; import static com.google.common.base.Preconditions.checkState; -import static io.trino.operator.WorkProcessorOperatorAdapter.createAdapterOperatorFactory; +import static io.trino.operator.BasicWorkProcessorOperatorAdapter.createAdapterOperatorFactory; import static java.util.Objects.requireNonNull; /** * Returns the top N rows from the source sorted according to the specified ordering in the keyChannelIndex channel. */ public class TopNOperator - implements AdapterWorkProcessorOperator + implements WorkProcessorOperator { public static OperatorFactory createOperatorFactory( int operatorId, @@ -50,7 +48,7 @@ public static OperatorFactory createOperatorFactory( } private static class Factory - implements AdapterWorkProcessorOperatorFactory + implements BasicAdapterWorkProcessorOperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; @@ -87,21 +85,7 @@ public WorkProcessorOperator create( checkState(!closed, "Factory is already closed"); return new TopNOperator( processorContext.getMemoryTrackingContext(), - Optional.of(sourcePages), - sourceTypes, - n, - sortChannels, - sortOrders, - typeOperators); - } - - @Override - public AdapterWorkProcessorOperator createAdapterOperator(ProcessorContext processorContext) - { - checkState(!closed, "Factory is already closed"); - return new TopNOperator( - processorContext.getMemoryTrackingContext(), - Optional.empty(), + sourcePages, sourceTypes, n, sortChannels, @@ -142,11 +126,10 @@ public Factory duplicate() private final TopNProcessor topNProcessor; private final WorkProcessor pages; - private final PageBuffer pageBuffer = new PageBuffer(); private TopNOperator( MemoryTrackingContext memoryTrackingContext, - Optional> sourcePages, + WorkProcessor sourcePages, List types, int n, List sortChannels, @@ -165,7 +148,7 @@ private TopNOperator( pages = WorkProcessor.of(); } else { - pages = sourcePages.orElse(pageBuffer.pages()).transform(new TopNPages()); + pages = sourcePages.transform(new TopNPages()); } } @@ -175,29 +158,6 @@ public WorkProcessor getOutputPages() return pages; } - @Override - public boolean needsInput() - { - return pageBuffer.isEmpty() && !pageBuffer.isFinished(); - } - - @Override - public void addInput(Page page) - { - addPage(page); - } - - @Override - public void finish() - { - pageBuffer.finish(); - } - - private void addPage(Page page) - { - topNProcessor.addInput(page); - } - private class TopNPages implements WorkProcessor.Transformation { @@ -205,7 +165,7 @@ private class TopNPages public TransformationState process(Page inputPage) { if (inputPage != null) { - addPage(inputPage); + topNProcessor.addInput(inputPage); return TransformationState.needsMoreData(); } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTopNOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestTopNOperator.java index dfef2508c3cd..d6629f1fc34a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestTopNOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestTopNOperator.java @@ -199,7 +199,8 @@ public void testExceedMemoryLimit() ImmutableList.of(0), ImmutableList.of(ASC_NULLS_LAST)); Operator operator = operatorFactory.createOperator(smallDiverContext); - assertThatThrownBy(() -> operator.addInput(input.get(0))) + operator.addInput(input.get(0)); + assertThatThrownBy(() -> operator.getOutput()) .isInstanceOf(ExceededMemoryLimitException.class) .hasMessageStartingWith("Query exceeded per-node memory limit of "); }