diff --git a/core/trino-main/src/main/java/io/trino/operator/WindowInfo.java b/core/trino-main/src/main/java/io/trino/operator/WindowInfo.java index ceaf8933aad0..2f6c778d11ab 100644 --- a/core/trino-main/src/main/java/io/trino/operator/WindowInfo.java +++ b/core/trino-main/src/main/java/io/trino/operator/WindowInfo.java @@ -219,7 +219,10 @@ public Optional build() } double avgSize = partitions.stream().mapToLong(Integer::longValue).average().getAsDouble(); double squaredDifferences = partitions.stream().mapToDouble(size -> Math.pow(size - avgSize, 2)).sum(); - checkState(partitions.stream().mapToLong(Integer::longValue).sum() == rowsNumber, "Total number of rows in index does not match number of rows in partitions within that index"); + if (partitions.stream().mapToLong(Integer::longValue).sum() != rowsNumber) { + // when operator is cancelled, then rows in index might not match row count from processed partitions + return Optional.empty(); + } return Optional.of(new IndexInfo(rowsNumber, sizeInBytes, squaredDifferences, partitions.size())); } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java index 0592902dd242..88adf84a5d48 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java @@ -17,6 +17,7 @@ import com.google.common.primitives.Ints; import io.airlift.units.DataSize; import io.trino.ExceededMemoryLimitException; +import io.trino.RowPagesBuilder; import io.trino.operator.WindowOperator.WindowOperatorFactory; import io.trino.operator.window.FirstValueFunction; import io.trino.operator.window.FrameInfo; @@ -71,6 +72,7 @@ import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @Test(singleThreaded = true) @@ -432,6 +434,41 @@ public void testFirstValuePartition(boolean spillEnabled, boolean revokeMemoryWh assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } + @Test + public void testClose() + throws Exception + { + RowPagesBuilder pageBuilder = rowPagesBuilder(VARCHAR, BIGINT); + for (int i = 0; i < 500_000; ++i) { + pageBuilder.row("a", 0L); + } + for (int i = 0; i < 500_000; ++i) { + pageBuilder.row("b", 0L); + } + List input = pageBuilder.build(); + + WindowOperatorFactory operatorFactory = createFactoryUnbounded( + ImmutableList.of(VARCHAR, BIGINT), + Ints.asList(0, 1), + ROW_NUMBER, + Ints.asList(0), + Ints.asList(1), + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + false); + + DriverContext driverContext = createDriverContext(1000); + Operator operator = operatorFactory.createOperator(driverContext); + operatorFactory.noMoreOperators(); + assertFalse(operator.isFinished()); + assertTrue(operator.needsInput()); + operator.addInput(input.get(0)); + operator.finish(); + operator.getOutput(); + + // this should not fail + operator.close(); + } + @Test(dataProvider = "spillEnabled") public void testLastValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) {